Skip to content

API Reference

This page contains the API documentation for all Python modules in the codebase (excluding `__init__.py` files).

aiperf.cli

Main CLI entry point for the AIPerf system.

analyze(user_config, service_config=None)

Sweep through one or more parameters.

Source code in aiperf/cli.py
42
43
44
45
46
47
48
49
50
51
@app.command(name="analyze")
def analyze(
    user_config: UserConfig,
    service_config: ServiceConfig | None = None,
) -> None:
    """Sweep through one or more parameters.

    Args:
        user_config: User configuration for the benchmark.
        service_config: Service configuration options.
    """
    # TODO: Implement this
    # Deferred import: loaded only when this command actually runs.
    from aiperf.cli_runner import warn_command_not_implemented

    warn_command_not_implemented("analyze")

create_template(template_filename=CLIDefaults.TEMPLATE_FILENAME)

Create a template configuration file.

Source code in aiperf/cli.py
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
@app.command(name="create-template")
def create_template(
    template_filename: Annotated[
        str,
        Field(
            description=f"Path to the template file. Defaults to {CLIDefaults.TEMPLATE_FILENAME}."
        ),
        cyclopts.Parameter(
            name=("--template-filename", "-t"),
        ),
    ] = CLIDefaults.TEMPLATE_FILENAME,
) -> None:
    """Create a template configuration file.

    Args:
        template_filename: Path where the template file should be written.
    """
    # TODO: Implement this
    # Deferred import: loaded only when this command actually runs.
    from aiperf.cli_runner import warn_command_not_implemented

    warn_command_not_implemented("create-template")

profile(user_config, service_config=None)

Run the Profile subcommand.

Parameters:

Name Type Description Default
user_config UserConfig

User configuration for the benchmark

required
service_config ServiceConfig | None

Service configuration options

None
Source code in aiperf/cli.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
@app.command(name="profile")
def profile(
    user_config: UserConfig,
    service_config: ServiceConfig | None = None,
) -> None:
    """Run the Profile subcommand.

    Args:
        user_config: User configuration for the benchmark
        service_config: Service configuration options
    """
    # Deferred imports: loaded only when this command actually runs.
    from aiperf.cli_runner import run_system_controller
    from aiperf.common.config import load_service_config

    # Fall back to a freshly loaded service configuration when none was given.
    service_config = service_config or load_service_config()

    run_system_controller(user_config, service_config)

validate_config(user_config=None, service_config=None)

Validate the configuration file.

Source code in aiperf/cli.py
73
74
75
76
77
78
79
80
81
82
@app.command(name="validate-config")
def validate_config(
    user_config: UserConfig | None = None,
    service_config: ServiceConfig | None = None,
) -> None:
    """Validate the configuration file.

    Args:
        user_config: User configuration to validate, if any.
        service_config: Service configuration to validate, if any.
    """
    # TODO: Implement this
    # Deferred import: loaded only when this command actually runs.
    from aiperf.cli_runner import warn_command_not_implemented

    warn_command_not_implemented("validate-config")

aiperf.cli_runner

run_system_controller(user_config, service_config)

Run the system controller with the given configuration.

Source code in aiperf/cli_runner.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def run_system_controller(
    user_config: UserConfig,
    service_config: ServiceConfig,
) -> None:
    """Run the system controller with the given configuration.

    Args:
        user_config: User configuration for the benchmark run.
        service_config: Service configuration options.

    Raises:
        Exception: Any error raised during service bootstrap is logged and
            re-raised to the caller.
    """

    # Deferred imports: loaded only when this entry point actually runs.
    from aiperf.common.aiperf_logger import AIPerfLogger
    from aiperf.common.bootstrap import bootstrap_and_run_service
    from aiperf.services import SystemController

    logger = AIPerfLogger(__name__)

    # NOTE(review): log_queue is never assigned anything but None here, so
    # bootstrap_and_run_service always receives log_queue=None — confirm
    # whether a real queue should be created when the UI is enabled.
    log_queue = None
    if service_config.disable_ui:
        from aiperf.common.logging import setup_rich_logging

        setup_rich_logging(user_config, service_config)

    # Create and start the system controller
    logger.info("Starting AIPerf System")

    try:
        bootstrap_and_run_service(
            SystemController,
            service_id="system_controller",
            service_config=service_config,
            user_config=user_config,
            log_queue=log_queue,
        )
    except Exception:
        logger.exception("Error starting AIPerf System")
        raise
    finally:
        # Runs on both success and failure paths.
        logger.info("AIPerf System exited")

warn_command_not_implemented(command)

Warn the user that the subcommand is not implemented.

Source code in aiperf/cli_runner.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def warn_command_not_implemented(command: str) -> None:
    """Warn the user that the subcommand is not implemented.

    Renders a red error panel on the terminal and exits the process with a
    non-zero status so calling scripts can detect the failure.
    """
    import sys

    from rich.console import Console
    from rich.panel import Panel

    error_panel = Panel(
        f"Command [bold red]{command}[/bold red] is not yet implemented",
        title="Error",
        title_align="left",
        border_style="red",
    )
    Console().print(error_panel)

    sys.exit(1)

aiperf.clients.client_interfaces

InferenceClientFactory

Bases: FactoryMixin[EndpointType, InferenceClientProtocol]

Factory for registering and creating InferenceClientProtocol instances based on the specified endpoint type. See `FactoryMixin` for more details.

Source code in aiperf/clients/client_interfaces.py
56
57
58
59
class InferenceClientFactory(FactoryMixin[EndpointType, InferenceClientProtocol]):
    """Factory for registering and creating InferenceClientProtocol instances based on the specified endpoint type.
    see: :class:`FactoryMixin` for more details.
    """
    # Empty body: all registration/creation behavior comes from FactoryMixin;
    # this subclass only binds the key type (EndpointType) and product type.

InferenceClientProtocol

Bases: Protocol, Generic[RequestInputT]

Protocol for an inference server client.

This protocol defines the methods that must be implemented by any inference server client implementation that is compatible with the AIPerf framework.

Source code in aiperf/clients/client_interfaces.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# runtime_checkable allows isinstance() checks against this protocol.
@runtime_checkable
class InferenceClientProtocol(Protocol, Generic[RequestInputT]):
    """Protocol for an inference server client.

    This protocol defines the methods that must be implemented by any inference server client
    implementation that is compatible with the AIPerf framework.
    """

    def __init__(self, model_endpoint: ModelEndpointInfo) -> None:
        """Create a new inference server client based on the provided configuration."""
        ...

    async def initialize(self) -> None:
        """Initialize the inference server client in an asynchronous context."""
        ...

    async def send_request(
        self,
        model_endpoint: ModelEndpointInfo,
        payload: RequestInputT,
    ) -> RequestRecord:
        """Send a request to the inference server.

        This method is used to send a request to the inference server.

        Args:
            model_endpoint: The endpoint to send the request to.
            payload: The payload to send to the inference server.
        Returns:
            The raw response from the inference server.
        """
        ...

    async def close(self) -> None:
        """Close the client."""
        ...

__init__(model_endpoint)

Create a new inference server client based on the provided configuration.

Source code in aiperf/clients/client_interfaces.py
26
27
28
def __init__(self, model_endpoint: ModelEndpointInfo) -> None:
    """Create a new inference server client based on the provided configuration."""
    # Protocol stub: concrete client classes supply the implementation.
    ...

close() async

Close the client.

Source code in aiperf/clients/client_interfaces.py
51
52
53
async def close(self) -> None:
    """Close the client."""
    # Protocol stub: concrete client classes supply the implementation.
    ...

initialize() async

Initialize the inference server client in an asynchronous context.

Source code in aiperf/clients/client_interfaces.py
30
31
32
async def initialize(self) -> None:
    """Initialize the inference server client in an asynchronous context."""
    # Protocol stub: concrete client classes supply the implementation.
    ...

send_request(model_endpoint, payload) async

Send a request to the inference server.

This method is used to send a request to the inference server.

Parameters:

Name Type Description Default
model_endpoint ModelEndpointInfo

The endpoint to send the request to.

required
payload RequestInputT

The payload to send to the inference server.

required

Returns: The raw response from the inference server.

Source code in aiperf/clients/client_interfaces.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
async def send_request(
    self,
    model_endpoint: ModelEndpointInfo,
    payload: RequestInputT,
) -> RequestRecord:
    """Send a request to the inference server.

    This method is used to send a request to the inference server.

    Args:
        model_endpoint: The endpoint to send the request to.
        payload: The payload to send to the inference server.
    Returns:
        The raw response from the inference server.
    """
    # Protocol stub: concrete client classes supply the implementation.
    ...

RequestConverterFactory

Bases: FactoryMixin[EndpointType, RequestConverterProtocol]

Factory for registering and creating RequestConverterProtocol instances based on the specified request payload type. See `FactoryMixin` for more details.

Source code in aiperf/clients/client_interfaces.py
78
79
80
81
class RequestConverterFactory(FactoryMixin[EndpointType, RequestConverterProtocol]):
    """Factory for registering and creating RequestConverterProtocol instances based on the specified request payload type.
    see: :class:`FactoryMixin` for more details.
    """
    # Empty body: all registration/creation behavior comes from FactoryMixin;
    # this subclass only binds the key type (EndpointType) and product type.

RequestConverterProtocol

Bases: Protocol, Generic[RequestOutputT]

Protocol for a request converter that converts a raw request to a formatted request for the inference server.

Source code in aiperf/clients/client_interfaces.py
67
68
69
70
71
72
73
74
75
# runtime_checkable allows isinstance() checks against this protocol.
@runtime_checkable
class RequestConverterProtocol(Protocol, Generic[RequestOutputT]):
    """Protocol for a request converter that converts a raw request to a formatted request for the inference server."""

    async def format_payload(
        self, model_endpoint: ModelEndpointInfo, turn: Turn
    ) -> RequestOutputT:
        """Format the turn for the inference server."""
        ...

format_payload(model_endpoint, turn) async

Format the turn for the inference server.

Source code in aiperf/clients/client_interfaces.py
71
72
73
74
75
async def format_payload(
    self, model_endpoint: ModelEndpointInfo, turn: Turn
) -> RequestOutputT:
    """Format the turn for the inference server."""
    # Protocol stub: concrete converters supply the implementation.
    ...

ResponseExtractorFactory

Bases: FactoryMixin[EndpointType, ResponseExtractorProtocol]

Factory for registering and creating ResponseExtractorProtocol instances based on the specified response extractor type. See `FactoryMixin` for more details.

Source code in aiperf/clients/client_interfaces.py
101
102
103
104
class ResponseExtractorFactory(FactoryMixin[EndpointType, ResponseExtractorProtocol]):
    """Factory for registering and creating ResponseExtractorProtocol instances based on the specified response extractor type.
    see: :class:`FactoryMixin` for more details.
    """
    # Empty body: all registration/creation behavior comes from FactoryMixin;
    # this subclass only binds the key type (EndpointType) and product type.

ResponseExtractorProtocol

Bases: Protocol

Protocol for a response extractor that extracts the response data from a raw inference server response and converts it to a list of ResponseData objects.

Source code in aiperf/clients/client_interfaces.py
89
90
91
92
93
94
95
96
97
98
# runtime_checkable allows isinstance() checks against this protocol.
@runtime_checkable
class ResponseExtractorProtocol(Protocol):
    """Protocol for a response extractor that extracts the response data from a raw inference server
    response and converts it to a list of ResponseData objects."""

    async def extract_response_data(
        self, record: RequestRecord, tokenizer: Tokenizer | None
    ) -> list[ResponseData]:
        """Extract the response data from a raw inference server response and convert it to a list of ResponseData objects."""
        ...

extract_response_data(record, tokenizer) async

Extract the response data from a raw inference server response and convert it to a list of ResponseData objects.

Source code in aiperf/clients/client_interfaces.py
94
95
96
97
98
async def extract_response_data(
    self, record: RequestRecord, tokenizer: Tokenizer | None
) -> list[ResponseData]:
    """Extract the response data from a raw inference server response and convert it to a list of ResponseData objects."""
    # Protocol stub: concrete extractors supply the implementation.
    ...

aiperf.clients.http.aiohttp_client

AioHttpClientMixin

A high-performance HTTP client for communicating with HTTP based REST APIs using aiohttp.

This class is optimized for maximum performance and accurate timing measurements, making it ideal for benchmarking scenarios.

Source code in aiperf/clients/http/aiohttp_client.py
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
class AioHttpClientMixin:
    """A high-performance HTTP client for communicating with HTTP based REST APIs using aiohttp.

    This class is optimized for maximum performance and accurate timing measurements,
    making it ideal for benchmarking scenarios.
    """

    def __init__(self, model_endpoint: ModelEndpointInfo) -> None:
        """Set up the logger, shared TCP connector, and timeout configuration."""
        super().__init__()
        self.logger = logging.getLogger(self.__class__.__name__)
        self.model_endpoint = model_endpoint
        # Shared connector reused across per-request sessions (see post_request).
        self.tcp_connector = create_tcp_connector()

        # For now, just set all timeouts to the same value.
        # TODO: Add support for different timeouts for different parts of the request.
        self.timeout = aiohttp.ClientTimeout(
            total=self.model_endpoint.endpoint.timeout,
            connect=self.model_endpoint.endpoint.timeout,
            sock_connect=self.model_endpoint.endpoint.timeout,
            sock_read=self.model_endpoint.endpoint.timeout,
            ceil_threshold=self.model_endpoint.endpoint.timeout,
        )

    async def close(self) -> None:
        """Close the client."""
        # Sessions are created with connector_owner=False, so the shared
        # connector must be closed explicitly here.
        if self.tcp_connector:
            await self.tcp_connector.close()
            self.tcp_connector = None

    async def post_request(
        self,
        url: str,
        payload: str,
        headers: dict[str, str],
        **kwargs: Any,
    ) -> RequestRecord:
        """Send a streaming or non-streaming POST request to the specified URL with the given payload and headers.

        If the response is an SSE stream, the response will be parsed into a list of SSE messages.
        Otherwise, the response will be parsed into a TextResponse object.

        Errors are captured into the returned record rather than raised, so
        callers always receive a RequestRecord with timing information.
        """

        self.logger.debug("Sending POST request to %s", url)

        record: RequestRecord = RequestRecord(
            start_perf_ns=time.perf_counter_ns(),
        )

        try:
            # Make raw HTTP request with precise timing using aiohttp.
            # connector_owner=False keeps the shared connector alive after
            # this per-request session closes.
            async with aiohttp.ClientSession(
                connector=self.tcp_connector,
                timeout=self.timeout,
                headers=headers,
                # Suppress aiohttp's automatic headers so only the
                # caller-supplied headers are sent.
                skip_auto_headers=[
                    *list(headers.keys()),
                    "User-Agent",
                    "Accept-Encoding",
                ],
                connector_owner=False,
            ) as session:
                # Re-stamp the start time just before sending, so session
                # setup cost is excluded from the measurement.
                record.start_perf_ns = time.perf_counter_ns()
                async with session.post(
                    url, data=payload, headers=headers, **kwargs
                ) as response:
                    record.status = response.status
                    # Check for HTTP errors
                    if response.status != 200:
                        error_text = await response.text()
                        record.error = ErrorDetails(
                            code=response.status,
                            type=response.reason,
                            message=error_text,
                        )
                        return record

                    record.recv_start_perf_ns = time.perf_counter_ns()

                    if response.content_type == "text/event-stream":
                        # Parse SSE stream with optimal performance
                        messages = await AioHttpSSEStreamReader(
                            response
                        ).read_complete_stream()
                        record.responses.extend(messages)
                    else:
                        raw_response = await response.text()
                        record.end_perf_ns = time.perf_counter_ns()
                        record.responses.append(
                            TextResponse(
                                perf_ns=record.end_perf_ns,
                                content_type=response.content_type,
                                text=raw_response,
                            )
                        )
                    # NOTE(review): in the non-SSE branch end_perf_ns was
                    # already stamped above; this assignment overwrites it.
                    record.end_perf_ns = time.perf_counter_ns()

        except Exception as e:
            # Capture failures into the record instead of propagating them.
            record.end_perf_ns = time.perf_counter_ns()
            self.logger.error("Error in aiohttp request: %s", str(e))
            record.error = ErrorDetails(type=e.__class__.__name__, message=str(e))

        return record

close() async

Close the client.

Source code in aiperf/clients/http/aiohttp_client.py
50
51
52
53
54
async def close(self) -> None:
    """Close the client."""
    # Close the shared TCP connector and clear the reference so it
    # cannot be reused afterwards.
    if self.tcp_connector:
        await self.tcp_connector.close()
        self.tcp_connector = None

post_request(url, payload, headers, **kwargs) async

Send a streaming or non-streaming POST request to the specified URL with the given payload and headers.

If the response is an SSE stream, the response will be parsed into a list of SSE messages. Otherwise, the response will be parsed into a TextResponse object.

Source code in aiperf/clients/http/aiohttp_client.py
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
async def post_request(
    self,
    url: str,
    payload: str,
    headers: dict[str, str],
    **kwargs: Any,
) -> RequestRecord:
    """Send a streaming or non-streaming POST request to the specified URL with the given payload and headers.

    If the response is an SSE stream, the response will be parsed into a list of SSE messages.
    Otherwise, the response will be parsed into a TextResponse object.

    Errors are captured into the returned record rather than raised.
    """

    self.logger.debug("Sending POST request to %s", url)

    record: RequestRecord = RequestRecord(
        start_perf_ns=time.perf_counter_ns(),
    )

    try:
        # Make raw HTTP request with precise timing using aiohttp.
        # connector_owner=False keeps the shared connector alive after this
        # per-request session closes.
        async with aiohttp.ClientSession(
            connector=self.tcp_connector,
            timeout=self.timeout,
            headers=headers,
            # Suppress aiohttp's automatic headers so only the
            # caller-supplied headers are sent.
            skip_auto_headers=[
                *list(headers.keys()),
                "User-Agent",
                "Accept-Encoding",
            ],
            connector_owner=False,
        ) as session:
            # Re-stamp the start time just before sending, excluding
            # session setup cost from the measurement.
            record.start_perf_ns = time.perf_counter_ns()
            async with session.post(
                url, data=payload, headers=headers, **kwargs
            ) as response:
                record.status = response.status
                # Check for HTTP errors
                if response.status != 200:
                    error_text = await response.text()
                    record.error = ErrorDetails(
                        code=response.status,
                        type=response.reason,
                        message=error_text,
                    )
                    return record

                record.recv_start_perf_ns = time.perf_counter_ns()

                if response.content_type == "text/event-stream":
                    # Parse SSE stream with optimal performance
                    messages = await AioHttpSSEStreamReader(
                        response
                    ).read_complete_stream()
                    record.responses.extend(messages)
                else:
                    raw_response = await response.text()
                    record.end_perf_ns = time.perf_counter_ns()
                    record.responses.append(
                        TextResponse(
                            perf_ns=record.end_perf_ns,
                            content_type=response.content_type,
                            text=raw_response,
                        )
                    )
                # NOTE(review): in the non-SSE branch end_perf_ns was
                # already stamped above; this assignment overwrites it.
                record.end_perf_ns = time.perf_counter_ns()

    except Exception as e:
        # Capture failures into the record instead of propagating them.
        record.end_perf_ns = time.perf_counter_ns()
        self.logger.error("Error in aiohttp request: %s", str(e))
        record.error = ErrorDetails(type=e.__class__.__name__, message=str(e))

    return record

AioHttpSSEStreamReader

A helper class for reading an SSE stream from an aiohttp.ClientResponse object.

This class is optimized for maximum performance and accurate timing measurements, making it ideal for benchmarking scenarios.

Source code in aiperf/clients/http/aiohttp_client.py
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
class AioHttpSSEStreamReader:
    """A helper class for reading an SSE stream from an aiohttp.ClientResponse object.

    This class is optimized for maximum performance and accurate timing measurements,
    making it ideal for benchmarking scenarios.
    """

    def __init__(self, response: aiohttp.ClientResponse):
        # The response whose content stream will be consumed.
        self.response = response

    async def read_complete_stream(self) -> list[SSEMessage]:
        """Read the complete SSE stream in a performant manner and return a list of
        SSE messages that contain the most accurate timestamp data possible.

        Returns:
            A list of SSE messages.
        """
        messages: list[SSEMessage] = []

        async for raw_message, first_byte_ns in self.__aiter__():
            # Parse the raw SSE message into a SSEMessage object
            message = parse_sse_message(raw_message, first_byte_ns)
            messages.append(message)

        return messages

    async def __aiter__(self) -> typing.AsyncIterator[tuple[str, int]]:
        """Iterate over the SSE stream in a performant manner, yielding tuples of
        the raw SSE message and the perf_counter_ns timestamp of its first byte.

        The first byte of each event is read separately so its timestamp is
        captured immediately, before the rest of the event is read, avoiding
        delays introduced by reading whole chunks at once.

        Returns:
            An async iterator of tuples of the raw SSE message, and the perf_counter_ns of the first byte
        """

        while not self.response.content.at_eof():
            # Read the first byte of the SSE stream
            first_byte = await self.response.content.read(1)
            chunk_ns_first_byte = time.perf_counter_ns()
            if not first_byte:
                break

            # Read the rest of the event up to the blank-line separator.
            # NOTE(review): confirm behavior when the stream ends without a
            # trailing b"\n\n" — readuntil may raise in that case.
            chunk = await self.response.content.readuntil(b"\n\n")

            if not chunk:
                break
            # Re-attach the first byte that was read separately for timing.
            chunk = first_byte + chunk

            try:
                # Use the fastest available decoder
                yield (
                    chunk.decode("utf-8").strip(),
                    chunk_ns_first_byte,
                )
            except UnicodeDecodeError:
                # Handle potential encoding issues gracefully
                yield (
                    chunk.decode("utf-8", errors="replace").strip(),
                    chunk_ns_first_byte,
                )

__aiter__() async

Iterate over the SSE stream in a performant manner, yielding tuples of the raw SSE message and the perf_counter_ns timestamp of its first byte. The first byte of each event is read separately so its timestamp is captured immediately, before the rest of the event is read, which provides the most accurate timing information possible given the buffering behavior of the aiohttp library.

Returns:

Type Description
AsyncIterator[tuple[str, int]]

An async iterator of tuples of the raw SSE message, and the perf_counter_ns of the first byte

Source code in aiperf/clients/http/aiohttp_client.py
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
async def __aiter__(self) -> typing.AsyncIterator[tuple[str, int]]:
    """Iterate over the SSE stream in a performant manner, yielding tuples of
    the raw SSE message and the perf_counter_ns timestamp of its first byte.

    The first byte of each event is read separately so its timestamp is
    captured immediately, before the rest of the event is read, avoiding
    delays introduced by reading whole chunks at once.

    Returns:
        An async iterator of tuples of the raw SSE message, and the perf_counter_ns of the first byte
    """

    while not self.response.content.at_eof():
        # Read the first byte of the SSE stream
        first_byte = await self.response.content.read(1)
        chunk_ns_first_byte = time.perf_counter_ns()
        if not first_byte:
            break

        # Read the rest of the event up to the blank-line separator.
        # NOTE(review): confirm behavior when the stream ends without a
        # trailing b"\n\n" — readuntil may raise in that case.
        chunk = await self.response.content.readuntil(b"\n\n")

        if not chunk:
            break
        # Re-attach the first byte that was read separately for timing.
        chunk = first_byte + chunk

        try:
            # Use the fastest available decoder
            yield (
                chunk.decode("utf-8").strip(),
                chunk_ns_first_byte,
            )
        except UnicodeDecodeError:
            # Handle potential encoding issues gracefully
            yield (
                chunk.decode("utf-8", errors="replace").strip(),
                chunk_ns_first_byte,
            )

read_complete_stream() async

Read the complete SSE stream in a performant manner and return a list of SSE messages that contain the most accurate timestamp data possible.

Returns:

Type Description
list[SSEMessage]

A list of SSE messages.

Source code in aiperf/clients/http/aiohttp_client.py
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
async def read_complete_stream(self) -> list[SSEMessage]:
    """Read the complete SSE stream in a performant manner and return a list of
    SSE messages that contain the most accurate timestamp data possible.

    Returns:
        A list of SSE messages.
    """
    # Parse each raw SSE payload into an SSEMessage, carrying along the
    # first-byte timestamp captured by the iterator.
    return [
        parse_sse_message(raw_message, first_byte_ns)
        async for raw_message, first_byte_ns in self.__aiter__()
    ]

create_tcp_connector(**kwargs)

Create a new connector with the given configuration.

Source code in aiperf/clients/http/aiohttp_client.py
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
def create_tcp_connector(**kwargs) -> aiohttp.TCPConnector:
    """Create a new connector with the given configuration.

    Caller-supplied keyword arguments override the library defaults.
    """

    def socket_factory(addr_info):
        """Custom socket factory optimized for SSE streaming performance."""
        family, sock_type, proto, _, _ = addr_info
        sock = socket.socket(family=family, type=sock_type, proto=proto)
        SocketDefaults.apply_to_socket(sock)
        return sock

    # Merge defaults with overrides; values in kwargs win.
    connector_kwargs: dict[str, Any] = {
        "limit": AioHttpDefaults.LIMIT,
        "limit_per_host": AioHttpDefaults.LIMIT_PER_HOST,
        "ttl_dns_cache": AioHttpDefaults.TTL_DNS_CACHE,
        "use_dns_cache": AioHttpDefaults.USE_DNS_CACHE,
        "enable_cleanup_closed": AioHttpDefaults.ENABLE_CLEANUP_CLOSED,
        "force_close": AioHttpDefaults.FORCE_CLOSE,
        "keepalive_timeout": AioHttpDefaults.KEEPALIVE_TIMEOUT,
        "happy_eyeballs_delay": AioHttpDefaults.HAPPY_EYEBALLS_DELAY,
        "family": AioHttpDefaults.SOCKET_FAMILY,
        "socket_factory": socket_factory,
    } | kwargs

    return aiohttp.TCPConnector(**connector_kwargs)

parse_sse_message(raw_message, perf_ns)

Parse a raw SSE message into an SSEMessage object.

Parsing logic based on official HTML SSE Living Standard: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream

Source code in aiperf/clients/http/aiohttp_client.py
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
def parse_sse_message(raw_message: str, perf_ns: int) -> SSEMessage:
    """Parse a raw SSE message into an SSEMessage object.

    Parsing logic based on official HTML SSE Living Standard:
    https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream
    """

    message = SSEMessage(perf_ns=perf_ns)
    for raw_line in raw_message.split("\n"):
        line = raw_line.strip()
        if not line:
            # Skip blank lines.
            continue

        field_name, colon, value = line.partition(":")
        if not colon:
            # Fields without a colon have no value, so the whole line is the field name
            message.packets.append(SSEField(name=field_name.strip(), value=None))
            continue

        # Field name is empty, so this is a comment
        if field_name == "":
            field_name = SSEFieldType.COMMENT

        message.packets.append(SSEField(name=field_name.strip(), value=value.strip()))

    return message

aiperf.clients.http.defaults

AioHttpDefaults dataclass

Default values for aiohttp.ClientSession.

Source code in aiperf/clients/http/defaults.py
62
63
64
65
66
67
68
69
70
71
72
73
74
@dataclass(frozen=True)
class AioHttpDefaults:
    """Default values for aiohttp.ClientSession."""

    # NOTE(review): these attributes carry no type annotations, so the
    # dataclass machinery treats them as plain class attributes rather than
    # dataclass fields (frozen= does not protect them) — confirm intent.
    LIMIT = 2500  # Maximum number of concurrent connections
    LIMIT_PER_HOST = 2500  # Maximum number of concurrent connections per host
    TTL_DNS_CACHE = 300  # Time to live for DNS cache
    USE_DNS_CACHE = True  # Enable DNS cache
    ENABLE_CLEANUP_CLOSED = False  # Disable cleanup of closed connections
    FORCE_CLOSE = False  # Disable force close connections
    KEEPALIVE_TIMEOUT = 300  # Keepalive timeout
    HAPPY_EYEBALLS_DELAY = None  # Happy eyeballs delay (None = disabled)
    SOCKET_FAMILY = socket.AF_INET  # Family of the socket (IPv4)

SocketDefaults dataclass

Default values for socket options.

Source code in aiperf/clients/http/defaults.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
@dataclass(frozen=True)
class SocketDefaults:
    """Default socket option values tuned for long-lived streaming connections."""

    # Latency: send small writes immediately and ACK quickly.
    TCP_NODELAY = 1  # Disable Nagle's algorithm
    TCP_QUICKACK = 1  # Quick ACK mode

    # Keepalive: detect dead peers on long-lived connections.
    SO_KEEPALIVE = 1  # Enable keepalive
    TCP_KEEPIDLE = 60  # Start keepalive probes after 60 seconds idle
    TCP_KEEPINTVL = 30  # Seconds between keepalive probes
    TCP_KEEPCNT = 1  # One failed keepalive probe marks the connection dead

    # Address/port reuse and teardown behavior.
    SO_LINGER = 0  # Disable linger
    SO_REUSEADDR = 1  # Enable reuse address
    SO_REUSEPORT = 1  # Enable reuse port

    # Throughput: large kernel buffers for streaming.
    SO_RCVBUF = 1024 * 1024 * 10  # 10MB receive buffer
    SO_SNDBUF = 1024 * 1024 * 10  # 10MB send buffer

    # Timeouts.
    SO_RCVTIMEO = 30  # 30 second receive timeout
    SO_SNDTIMEO = 30  # 30 second send timeout
    TCP_USER_TIMEOUT = 30000  # 30 sec user timeout

    @classmethod
    def apply_to_socket(cls, sock: socket.socket) -> None:
        """Apply the default low-latency and keepalive options to ``sock``.

        Options not exposed by this platform's ``socket`` module (the
        Linux-specific keepalive timing, quick-ACK, and user-timeout knobs)
        are skipped.
        """
        set_opt = sock.setsockopt

        # Send small writes immediately (no Nagle batching) for streaming.
        set_opt(socket.SOL_TCP, socket.TCP_NODELAY, cls.TCP_NODELAY)

        # Keep long-lived SSE connections alive and probed.
        set_opt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, cls.SO_KEEPALIVE)

        # Fine-grained keepalive timing is Linux-specific.
        if hasattr(socket, "TCP_KEEPIDLE"):
            for option, value in (
                (socket.TCP_KEEPIDLE, cls.TCP_KEEPIDLE),
                (socket.TCP_KEEPINTVL, cls.TCP_KEEPINTVL),
                (socket.TCP_KEEPCNT, cls.TCP_KEEPCNT),
            ):
                set_opt(socket.SOL_TCP, option, value)

        # Large kernel buffers for sustained streaming throughput.
        set_opt(socket.SOL_SOCKET, socket.SO_RCVBUF, cls.SO_RCVBUF)
        set_opt(socket.SOL_SOCKET, socket.SO_SNDBUF, cls.SO_SNDBUF)

        # Remaining Linux-only TCP tuning, applied only when available.
        if hasattr(socket, "TCP_QUICKACK"):
            set_opt(socket.SOL_TCP, socket.TCP_QUICKACK, cls.TCP_QUICKACK)
        if hasattr(socket, "TCP_USER_TIMEOUT"):
            set_opt(socket.SOL_TCP, socket.TCP_USER_TIMEOUT, cls.TCP_USER_TIMEOUT)

apply_to_socket(sock) classmethod

Apply the default socket options to the given socket.

Source code in aiperf/clients/http/defaults.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
@classmethod
def apply_to_socket(cls, sock: socket.socket) -> None:
    """Apply the default socket options to the given socket.

    TCP_NODELAY and SO_KEEPALIVE are set unconditionally; the remaining
    keepalive-timing, quick-ACK, and user-timeout options are applied only
    on platforms whose ``socket`` module exposes them (Linux).
    """

    # Low-latency optimizations for streaming
    sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, cls.TCP_NODELAY)

    # Connection keepalive settings for long-lived SSE connections
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, cls.SO_KEEPALIVE)

    # Fine-tune keepalive timing (Linux-specific)
    # NOTE: one hasattr check guards all three options; they ship together on Linux.
    if hasattr(socket, "TCP_KEEPIDLE"):
        sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE, cls.TCP_KEEPIDLE)
        sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL, cls.TCP_KEEPINTVL)
        sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT, cls.TCP_KEEPCNT)

    # Buffer size optimizations for streaming
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, cls.SO_RCVBUF)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, cls.SO_SNDBUF)

    # Linux-specific TCP optimizations
    if hasattr(socket, "TCP_QUICKACK"):
        sock.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, cls.TCP_QUICKACK)

    if hasattr(socket, "TCP_USER_TIMEOUT"):
        sock.setsockopt(
            socket.SOL_TCP, socket.TCP_USER_TIMEOUT, cls.TCP_USER_TIMEOUT
        )

aiperf.clients.model_endpoint_info

Model endpoint information.

This module contains the pydantic models that encapsulate the information needed to send requests to an inference server, primarily around the model, endpoint, and additional request payload information.

EndpointInfo

Bases: AIPerfBaseModel

Information about an endpoint.

Source code in aiperf/clients/model_endpoint_info.py
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
class EndpointInfo(AIPerfBaseModel):
    """Information about an endpoint.

    Bundles everything needed to address a single inference endpoint:
    payload type, URL parts, streaming flag, headers/auth, SSL options,
    timeout, and extra inputs merged into every request payload.
    """

    type: EndpointType = Field(
        default=EndpointType.OPENAI_CHAT_COMPLETIONS,
        description="The type of request payload to use for the endpoint.",
    )
    base_url: str | None = Field(
        default=None,
        description="URL of the endpoint.",
    )
    custom_endpoint: str | None = Field(
        default=None,
        description="Custom endpoint to use for the models.",
    )
    url_params: dict[str, Any] | None = Field(
        default=None, description="Custom URL parameters to use for the endpoint."
    )
    streaming: bool = Field(
        default=False,
        description="Whether the endpoint supports streaming.",
    )
    headers: dict[str, str] | None = Field(
        default=None,
        description="Custom URL headers to use for the endpoint.",
    )
    api_key: str | None = Field(
        default=None,
        description="API key to use for the endpoint.",
    )
    ssl_options: dict[str, Any] | None = Field(
        default=None,
        description="SSL options to use for the endpoint.",
    )
    timeout: float = Field(
        default=EndPointDefaults.TIMEOUT,
        description="The timeout in seconds for each request to the endpoint.",
    )
    extra: dict[str, Any] | None = Field(
        default=None,
        description="Additional inputs to include with every request. "
        "You can repeat this flag for multiple inputs. Inputs should be in an 'input_name:value' format. "
        "Alternatively, a string representing a json formatted dict can be provided.",
    )

    @classmethod
    def from_user_config(cls, user_config: UserConfig) -> "EndpointInfo":
        """Create an EndpointInfo from a UserConfig."""
        return cls(
            type=EndpointType(user_config.endpoint.type),
            custom_endpoint=user_config.endpoint.custom_endpoint,
            streaming=user_config.endpoint.streaming,
            base_url=user_config.endpoint.url,
            headers=user_config.input.headers,
            extra=user_config.input.extra,
            timeout=user_config.endpoint.timeout_seconds,
            api_key=user_config.endpoint.api_key,
        )

from_user_config(user_config) classmethod

Create an HttpEndpointInfo from a UserConfig.

Source code in aiperf/clients/model_endpoint_info.py
106
107
108
109
110
111
112
113
114
115
116
117
118
@classmethod
def from_user_config(cls, user_config: UserConfig) -> "EndpointInfo":
    """Create an EndpointInfo from a UserConfig.

    Pulls endpoint settings (type, URL, streaming, timeout, API key) from
    ``user_config.endpoint`` and request extras/headers from ``user_config.input``.
    """
    return cls(
        type=EndpointType(user_config.endpoint.type),
        custom_endpoint=user_config.endpoint.custom_endpoint,
        streaming=user_config.endpoint.streaming,
        base_url=user_config.endpoint.url,
        headers=user_config.input.headers,
        extra=user_config.input.extra,
        timeout=user_config.endpoint.timeout_seconds,
        api_key=user_config.endpoint.api_key,
    )

ModelEndpointInfo

Bases: AIPerfBaseModel

Information about a model endpoint.

Source code in aiperf/clients/model_endpoint_info.py
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
class ModelEndpointInfo(AIPerfBaseModel):
    """Information about a model endpoint."""

    models: ModelListInfo = Field(
        ...,
        description="The models to use for the endpoint.",
    )
    endpoint: EndpointInfo = Field(
        ...,
        description="The endpoint to use for the models.",
    )

    @classmethod
    def from_user_config(cls, user_config: UserConfig) -> "ModelEndpointInfo":
        """Create a ModelEndpointInfo from a UserConfig."""
        model_list = ModelListInfo.from_user_config(user_config)
        endpoint_info = EndpointInfo.from_user_config(user_config)
        return cls(models=model_list, endpoint=endpoint_info)

    @property
    def url(self) -> str:
        """Get the full URL for the endpoint.

        The base URL (sans trailing slash) is joined with either the custom
        endpoint path or the default path for the endpoint type.
        """
        base = self.endpoint.base_url
        full_url = base.rstrip("/") if base else ""
        custom = self.endpoint.custom_endpoint
        if custom:
            return full_url + "/" + custom.lstrip("/")
        path = self.endpoint.type.endpoint_path()
        if path:
            return full_url + "/" + path.lstrip("/")
        return full_url

    @property
    def primary_model(self) -> ModelInfo:
        """Get the primary model (the first configured model)."""
        return self.models.models[0]

    @property
    def primary_model_name(self) -> str:
        """Get the primary model name."""
        return self.primary_model.name

primary_model property

Get the primary model.

primary_model_name property

Get the primary model name.

url property

Get the full URL for the endpoint.

from_user_config(user_config) classmethod

Create a ModelEndpointInfo from a UserConfig.

Source code in aiperf/clients/model_endpoint_info.py
133
134
135
136
137
138
139
@classmethod
def from_user_config(cls, user_config: UserConfig) -> "ModelEndpointInfo":
    """Create a ModelEndpointInfo from a UserConfig.

    Delegates construction of the two components to
    ``ModelListInfo.from_user_config`` and ``EndpointInfo.from_user_config``.
    """
    return cls(
        models=ModelListInfo.from_user_config(user_config),
        endpoint=EndpointInfo.from_user_config(user_config),
    )

ModelInfo

Bases: AIPerfBaseModel

Information about a model.

Source code in aiperf/clients/model_endpoint_info.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class ModelInfo(AIPerfBaseModel):
    """Information about a model.

    Identifies a model by name (and optional version) and records its
    modality, which drives the choice of request payload format.
    """

    name: str = Field(
        ...,
        min_length=1,
        description="The name of the model. This is used to identify the model.",
    )
    version: str | None = Field(
        default=None,
        description="The version of the model.",
    )
    modality: Modality = Field(
        default=Modality.TEXT,
        description="The modality of the model. This is used to determine the type of request payload "
        "to use for the endpoint. If CUSTOM, the model is not supported.",
    )

ModelListInfo

Bases: AIPerfBaseModel

Information about a list of models.

Source code in aiperf/clients/model_endpoint_info.py
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class ModelListInfo(AIPerfBaseModel):
    """Information about a list of models.

    Holds the (non-empty) list of candidate models together with the
    strategy used to pick one for each request.
    """

    models: list[ModelInfo] = Field(
        ...,
        min_length=1,
        description="The models to use for the endpoint.",
    )
    model_selection_strategy: ModelSelectionStrategy = Field(
        ...,
        description="The strategy to use for selecting the model to use for the endpoint.",
    )

    @classmethod
    def from_user_config(cls, user_config: UserConfig) -> "ModelListInfo":
        """Create a ModelListInfo from a UserConfig.

        Wraps each configured model name in a ModelInfo and carries over the
        endpoint's model selection strategy.
        """
        return cls(
            models=[ModelInfo(name=model) for model in user_config.model_names],
            model_selection_strategy=user_config.endpoint.model_selection_strategy,
        )

from_user_config(user_config) classmethod

Create a ModelListInfo from a UserConfig.

Source code in aiperf/clients/model_endpoint_info.py
52
53
54
55
56
57
58
@classmethod
def from_user_config(cls, user_config: UserConfig) -> "ModelListInfo":
    """Create a ModelListInfo from a UserConfig.

    Wraps each configured model name in a ModelInfo and carries over the
    endpoint's model selection strategy.
    """
    return cls(
        models=[ModelInfo(name=model) for model in user_config.model_names],
        model_selection_strategy=user_config.endpoint.model_selection_strategy,
    )

aiperf.clients.openai.openai_aiohttp

OpenAIClientAioHttp

Bases: AioHttpClientMixin, ABC

Inference client for OpenAI based requests using aiohttp.

Source code in aiperf/clients/openai/openai_aiohttp.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
@InferenceClientFactory.register_all(
    EndpointType.OPENAI_CHAT_COMPLETIONS,
    EndpointType.OPENAI_COMPLETIONS,
    EndpointType.OPENAI_RESPONSES,
)
class OpenAIClientAioHttp(AioHttpClientMixin, ABC):
    """Inference client for OpenAI based requests using aiohttp."""

    def __init__(self, model_endpoint: ModelEndpointInfo) -> None:
        super().__init__(model_endpoint)
        self.logger = logging.getLogger(self.__class__.__name__)
        self.model_endpoint = model_endpoint

    def get_headers(self, model_endpoint: ModelEndpointInfo) -> dict[str, str]:
        """Build the HTTP headers for a request to the given endpoint.

        Streaming requests accept server-sent events; user-supplied headers
        are applied last, so they override the defaults.
        """
        endpoint = model_endpoint.endpoint
        headers = {
            "User-Agent": "aiperf/1.0",
            "Content-Type": "application/json",
            "Accept": "text/event-stream" if endpoint.streaming else "application/json",
        }
        if endpoint.api_key:
            headers["Authorization"] = f"Bearer {endpoint.api_key}"
        if endpoint.headers:
            # Custom headers take precedence over the defaults above.
            headers.update(endpoint.headers)
        return headers

    def get_url(self, model_endpoint: ModelEndpointInfo) -> str:
        """Return the request URL, prepending the http scheme if it is missing."""
        url = model_endpoint.url
        return url if url.startswith("http") else f"http://{url}"

    async def send_request(
        self,
        model_endpoint: ModelEndpointInfo,
        payload: dict[str, Any],
    ) -> RequestRecord:
        """Send OpenAI request using aiohttp.

        Never raises: any exception is converted into a RequestRecord
        carrying the error details and request timing.
        """
        # capture start time before request is sent in the case of an error
        start_perf_ns = time.perf_counter_ns()
        try:
            self.logger.debug(
                "Sending OpenAI request to %s, payload: %s", model_endpoint.url, payload
            )
            record = await self.post_request(
                self.get_url(model_endpoint),
                json.dumps(payload),
                self.get_headers(model_endpoint),
            )
            record.request = payload
            return record
        except Exception as e:
            # Capture the end time immediately so logging does not skew it.
            end_perf_ns = time.perf_counter_ns()
            self.logger.exception(
                "Error in OpenAI request: %s %s",
                e.__class__.__name__,
                str(e),
            )
            return RequestRecord(
                request=payload,
                start_perf_ns=start_perf_ns,
                end_perf_ns=end_perf_ns,
                error=ErrorDetails(type=e.__class__.__name__, message=str(e)),
            )

get_headers(model_endpoint)

Get the headers for the given endpoint.

Source code in aiperf/clients/openai/openai_aiohttp.py
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
def get_headers(self, model_endpoint: ModelEndpointInfo) -> dict[str, str]:
    """Get the headers for the given endpoint.

    Streaming requests advertise ``text/event-stream``; otherwise plain
    JSON. User-supplied headers are applied last and override defaults.
    """

    # SSE media type for streaming responses, JSON otherwise.
    accept = (
        "text/event-stream"
        if model_endpoint.endpoint.streaming
        else "application/json"
    )

    headers = {
        "User-Agent": "aiperf/1.0",
        "Content-Type": "application/json",
        "Accept": accept,
    }
    if model_endpoint.endpoint.api_key:
        headers["Authorization"] = f"Bearer {model_endpoint.endpoint.api_key}"
    if model_endpoint.endpoint.headers:
        # Custom headers take precedence over the defaults above.
        headers.update(model_endpoint.endpoint.headers)
    return headers

get_url(model_endpoint)

Get the URL for the given endpoint.

Source code in aiperf/clients/openai/openai_aiohttp.py
55
56
57
58
59
60
def get_url(self, model_endpoint: ModelEndpointInfo) -> str:
    """Get the URL for the given endpoint.

    Prepends ``http://`` when the configured URL has no scheme; URLs that
    already start with "http" (including "https") are left untouched.
    """
    url = model_endpoint.url
    if not url.startswith("http"):
        url = f"http://{url}"
    return url

send_request(model_endpoint, payload) async

Send OpenAI request using aiohttp.

Source code in aiperf/clients/openai/openai_aiohttp.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
async def send_request(
    self,
    model_endpoint: ModelEndpointInfo,
    payload: dict[str, Any],
) -> RequestRecord:
    """Send OpenAI request using aiohttp.

    Never raises: on any exception an error RequestRecord is returned
    instead, so the caller always receives timing information.
    """

    # capture start time before request is sent in the case of an error
    start_perf_ns = time.perf_counter_ns()
    try:
        self.logger.debug(
            "Sending OpenAI request to %s, payload: %s", model_endpoint.url, payload
        )

        record = await self.post_request(
            self.get_url(model_endpoint),
            json.dumps(payload),
            self.get_headers(model_endpoint),
        )
        # Attach the original payload for downstream record processing.
        record.request = payload

    except Exception as e:
        record = RequestRecord(
            request=payload,
            start_perf_ns=start_perf_ns,
            end_perf_ns=time.perf_counter_ns(),
            error=ErrorDetails(type=e.__class__.__name__, message=str(e)),
        )
        self.logger.exception(
            "Error in OpenAI request: %s %s",
            e.__class__.__name__,
            str(e),
        )

    return record

aiperf.clients.openai.openai_chat

OpenAIChatCompletionRequestConverter

Bases: RequestConverterProtocol[dict[str, Any]]

Request converter for OpenAI chat completion requests.

Source code in aiperf/clients/openai/openai_chat.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
@RequestConverterFactory.register(EndpointType.OPENAI_CHAT_COMPLETIONS)
class OpenAIChatCompletionRequestConverter(RequestConverterProtocol[dict[str, Any]]):
    """Request converter for OpenAI chat completion requests."""

    def __init__(self) -> None:
        super().__init__()
        self.logger = logging.getLogger(self.__class__.__name__)

    async def format_payload(
        self,
        model_endpoint: ModelEndpointInfo,
        turn: Turn,
    ) -> dict[str, Any]:
        """Format payload for a chat completion request.

        One message is produced per non-empty text content in the turn;
        endpoint-level ``extra`` inputs are merged in last.
        """
        role = turn.role or DEFAULT_ROLE

        # TODO: Do we need to support image and audio inputs?
        messages = []
        for text in turn.texts:
            for content in text.contents:
                if content:
                    messages.append(
                        {"role": role, "name": text.name, "content": content}
                    )

        payload: dict[str, Any] = {
            "messages": messages,
            "model": model_endpoint.primary_model_name,
            "stream": model_endpoint.endpoint.streaming,
        }

        if model_endpoint.endpoint.extra:
            payload.update(model_endpoint.endpoint.extra)

        self.logger.debug("Formatted payload: %s", payload)
        return payload

format_payload(model_endpoint, turn) async

Format payload for a chat completion request.

Source code in aiperf/clients/openai/openai_chat.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
async def format_payload(
    self,
    model_endpoint: ModelEndpointInfo,
    turn: Turn,
) -> dict[str, Any]:
    """Format payload for a chat completion request.

    Builds one message per non-empty text content in the turn, then merges
    any endpoint-level ``extra`` inputs (which may override base keys).
    """

    # TODO: Do we need to support image and audio inputs?
    messages = [
        {
            "role": turn.role or DEFAULT_ROLE,
            "name": text.name,
            "content": content,
        }
        for text in turn.texts
        for content in text.contents
        if content
    ]

    payload = {
        "messages": messages,
        "model": model_endpoint.primary_model_name,
        "stream": model_endpoint.endpoint.streaming,
    }

    if model_endpoint.endpoint.extra:
        payload.update(model_endpoint.endpoint.extra)

    self.logger.debug("Formatted payload: %s", payload)
    return payload

aiperf.clients.openai.openai_completions

OpenAICompletionRequestConverter

Bases: RequestConverterProtocol[dict[str, Any]]

Request converter for OpenAI completion requests.

Source code in aiperf/clients/openai/openai_completions.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
@RequestConverterFactory.register(EndpointType.OPENAI_COMPLETIONS)
class OpenAICompletionRequestConverter(RequestConverterProtocol[dict[str, Any]]):
    """Request converter for OpenAI completion requests."""

    def __init__(self) -> None:
        super().__init__()
        self.logger = logging.getLogger(self.__class__.__name__)

    async def format_payload(
        self,
        model_endpoint: ModelEndpointInfo,
        turn: Turn,
    ) -> dict[str, Any]:
        """Format payload for a completion request.

        Collects the turn's non-empty text contents as prompts and merges
        endpoint-level ``extra`` inputs in last.
        """

        # TODO: Do we need to support image and audio inputs?
        prompts = []
        for text in turn.texts:
            for content in text.contents:
                if content:
                    prompts.append(content)

        payload: dict[str, Any] = {
            "prompt": prompts,
            "model": model_endpoint.primary_model_name,
            "stream": model_endpoint.endpoint.streaming,
        }

        extra = model_endpoint.endpoint.extra or {}
        if extra:
            payload.update(extra)

        self.logger.debug("Formatted payload: %s", payload)
        return payload

format_payload(model_endpoint, turn) async

Format payload for a completion request.

Source code in aiperf/clients/openai/openai_completions.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
async def format_payload(
    self,
    model_endpoint: ModelEndpointInfo,
    turn: Turn,
) -> dict[str, Any]:
    """Format payload for a completion request.

    Collects the turn's non-empty text contents as the prompt list and
    merges endpoint-level ``extra`` inputs (which may override base keys).
    """

    # TODO: Do we need to support image and audio inputs?
    prompts = [
        content for text in turn.texts for content in text.contents if content
    ]

    extra = model_endpoint.endpoint.extra or {}

    payload = {
        "prompt": prompts,
        "model": model_endpoint.primary_model_name,
        "stream": model_endpoint.endpoint.streaming,
    }

    if extra:
        payload.update(extra)

    self.logger.debug("Formatted payload: %s", payload)
    return payload

aiperf.clients.openai.openai_responses

OpenAIResponsesRequestConverter

Bases: RequestConverterProtocol[dict[str, Any]]

Request converter for OpenAI Responses requests.

Source code in aiperf/clients/openai/openai_responses.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
@RequestConverterFactory.register(EndpointType.OPENAI_RESPONSES)
class OpenAIResponsesRequestConverter(RequestConverterProtocol[dict[str, Any]]):
    """Request converter for OpenAI Responses requests."""

    def __init__(self) -> None:
        super().__init__()
        self.logger = logging.getLogger(self.__class__.__name__)

    async def format_payload(
        self,
        model_endpoint: ModelEndpointInfo,
        turn: Turn,
    ) -> dict[str, Any]:
        """Format payload for a responses request.

        Collects the turn's non-empty text contents as the input list and
        merges endpoint-level ``extra`` inputs in last.
        """

        # TODO: Add support for image and audio inputs.
        prompts = [
            content for text in turn.texts for content in text.contents if content
        ]

        # BUG FIX: copy the extras before popping. Popping directly from
        # model_endpoint.endpoint.extra mutated the shared dict, silently
        # dropping "max_output_tokens" from every subsequent request.
        extra = dict(model_endpoint.endpoint.extra or {})

        payload = {
            "input": prompts,
            "model": model_endpoint.primary_model_name,
            # TODO: How do we handle max_output_tokens? Should be provided by OSL logic
            "max_output_tokens": extra.pop("max_output_tokens", None),
            "stream": model_endpoint.endpoint.streaming,
        }

        if extra:
            payload.update(extra)

        self.logger.debug("Formatted payload: %s", payload)
        return payload

format_payload(model_endpoint, turn) async

Format payload for a responses request.

Source code in aiperf/clients/openai/openai_responses.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
async def format_payload(
    self,
    model_endpoint: ModelEndpointInfo,
    turn: Turn,
) -> dict[str, Any]:
    """Format payload for a responses request.

    Collects the turn's non-empty text contents as the input list and
    merges endpoint-level ``extra`` inputs in last.
    """

    # TODO: Add support for image and audio inputs.
    prompts = [
        content for text in turn.texts for content in text.contents if content
    ]

    # BUG FIX: copy the extras before popping. Popping directly from
    # model_endpoint.endpoint.extra mutated the shared dict, silently
    # dropping "max_output_tokens" from every subsequent request.
    extra = dict(model_endpoint.endpoint.extra or {})

    payload = {
        "input": prompts,
        "model": model_endpoint.primary_model_name,
        # TODO: How do we handle max_output_tokens? Should be provided by OSL logic
        "max_output_tokens": extra.pop("max_output_tokens", None),
        "stream": model_endpoint.endpoint.streaming,
    }

    if extra:
        payload.update(extra)

    self.logger.debug("Formatted payload: %s", payload)
    return payload

aiperf.common.aiperf_logger

AIPerfLogger

Logger for AIPerf messages with lazy evaluation support for f-strings.

This logger supports lazy evaluation of f-strings through lambdas to avoid expensive string formatting operations when the log level is not enabled.

It also extends the standard logging module with additional log levels
  • TRACE (TRACE < DEBUG)
  • NOTICE (INFO < NOTICE < WARNING)
  • SUCCESS (WARNING < SUCCESS < ERROR)
Usage

logger = AIPerfLogger("my_logger")
logger.debug(lambda: f"Processing {item} with {count} items")
logger.info("Simple string message")
logger.notice("Notice message")
logger.success("Benchmark completed successfully")

Need to pass local variables to the lambda to avoid them going out of scope

logger.debug(lambda i=i: f"Binding loop variable: {i}")
logger.exception(f"Direct f-string usage: {e}")

Source code in aiperf/common/aiperf_logger.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
class AIPerfLogger:
    """Logger for AIPerf messages with lazy evaluation support for f-strings.

    This logger supports lazy evaluation of f-strings through lambdas to avoid
    expensive string formatting operations when the log level is not enabled.

    It also extends the standard logging module with additional log levels:
        - TRACE    (TRACE < DEBUG)
        - NOTICE   (INFO < NOTICE < WARNING)
        - SUCCESS  (WARNING < SUCCESS < ERROR)

    Usage:
        logger = AIPerfLogger("my_logger")
        logger.debug(lambda: f"Processing {item} with {count} items")
        logger.info("Simple string message")
        logger.notice("Notice message")
        logger.success("Benchmark completed successfully")
        # Need to pass local variables to the lambda to avoid them going out of scope
        logger.debug(lambda i=i: f"Binding loop variable: {i}")
        logger.exception(f"Direct f-string usage: {e}")
    """

    def __init__(self, logger_name: str):
        self.logger_name = logger_name
        self._logger = logging.getLogger(logger_name)

        # Cache the internal logging module's _log method
        self._internal_log = self._logger._log

        # Forward the internal findCaller method to our custom method
        self._logger.findCaller = self.find_caller

        # Python style method names
        self.is_enabled_for = self._logger.isEnabledFor
        self.set_level = self._logger.setLevel
        self.get_effective_level = self._logger.getEffectiveLevel

        # Legacy logging method compatibility / passthrough
        self.isEnabledFor = self._logger.isEnabledFor
        self.setLevel = self._logger.setLevel
        self.getEffectiveLevel = self._logger.getEffectiveLevel
        self.handlers = self._logger.handlers
        self.addHandler = self._logger.addHandler
        self.removeHandler = self._logger.removeHandler
        self.hasHandlers = self._logger.hasHandlers
        self.root = self._logger.root

    def _log(self, level: int, msg: str | Callable[..., str], *args, **kwargs) -> None:
        """Internal log method that handles lazy evaluation of f-strings."""
        if callable(msg):
            # NOTE: Internal python logging _log method requires a tuple for the args, even if there are no args
            self._internal_log(level, msg(*args), (), **kwargs)
        else:
            self._internal_log(level, msg, args, **kwargs)

    @classmethod
    def is_valid_level(cls, level: int | str) -> bool:
        """Check if the given level is a valid level."""
        if isinstance(level, str):
            return level in [
                "TRACE",
                "DEBUG",
                "INFO",
                "NOTICE",
                "WARNING",
                "SUCCESS",
                "ERROR",
                "CRITICAL",
            ]
        else:
            return level in [
                _TRACE,
                _DEBUG,
                _INFO,
                _NOTICE,
                _WARNING,
                _SUCCESS,
                _ERROR,
                _CRITICAL,
            ]

    @classmethod
    def get_level_number(cls, level: int | str) -> int:
        """Get the numeric level for the given level."""
        if isinstance(level, str):
            return getattr(cls, level.upper())
        else:
            return level

    def find_caller(
        self, stack_info=False, stacklevel=1
    ) -> tuple[str, int, str, str | None]:
        """
        NOTE: This is a modified version of the findCaller method in the logging module,
        in order to allow us to add custom ignored files.

        Find the stack frame of the caller so that we can note the source
        file name, line number and function name.
        """
        f = currentframe()
        # On some versions of IronPython, currentframe() returns None if
        # IronPython isn't run with -X:Frames.
        if f is not None:
            f = f.f_back
        orig_f = f
        while f and stacklevel > 1:
            f = f.f_back
            stacklevel -= 1
        if not f:
            f = orig_f
        rv = "(unknown file)", 0, "(unknown function)", None
        while f and hasattr(f, "f_code"):
            co = f.f_code
            filename = os.path.normcase(co.co_filename)
            # NOTE: The if-statement below was modified to use our own list of ignored files (_ignored_files).
            # This is required to avoid it appearing as all logs are coming from this file.
            if filename in _ignored_files:
                f = f.f_back
                continue
            sinfo = None
            if stack_info:
                sio = io.StringIO()
                sio.write("Stack (most recent call last):\n")
                traceback.print_stack(f, file=sio)
                sinfo = sio.getvalue()
                if sinfo[-1] == "\n":
                    sinfo = sinfo[:-1]
                sio.close()
            rv = (co.co_filename, f.f_lineno, co.co_name, sinfo)
            break
        return rv

    def log(self, level: int, msg: str | Callable[..., str], *args, **kwargs) -> None:
        """Log a message with support for lazy evaluation using lambdas."""
        if self.is_enabled_for(level):
            self._log(level, msg, args, **kwargs)

    def trace(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
        """Log a trace message with support for lazy evaluation using lambdas."""
        if self.is_enabled_for(_TRACE):
            self._log(_TRACE, msg, *args, **kwargs)

    def debug(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
        """Log a debug message with support for lazy evaluation using lambdas."""
        if self.is_enabled_for(_DEBUG):
            self._log(_DEBUG, msg, *args, **kwargs)

    def info(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
        """Log an info message with support for lazy evaluation using lambdas."""
        if self.is_enabled_for(_INFO):
            self._log(_INFO, msg, *args, **kwargs)

    def notice(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
        """Log a notice message with support for lazy evaluation using lambdas."""
        if self.is_enabled_for(_NOTICE):
            self._log(_NOTICE, msg, *args, **kwargs)

    def warning(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
        """Log a warning message with support for lazy evaluation using lambdas."""
        if self.is_enabled_for(_WARNING):
            self._log(_WARNING, msg, *args, **kwargs)

    def success(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
        """Log a success message with support for lazy evaluation using lambdas."""
        if self.is_enabled_for(_SUCCESS):
            self._log(_SUCCESS, msg, *args, **kwargs)

    def error(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
        """Log an error message with support for lazy evaluation using lambdas."""
        if self.is_enabled_for(_ERROR):
            self._log(_ERROR, msg, *args, **kwargs)

    def exception(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
        """Log an exception message with support for lazy evaluation using lambdas."""
        if self.is_enabled_for(_ERROR):
            self._log(_ERROR, msg, *args, exc_info=True, **kwargs)

    def critical(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
        """Log a critical message with support for lazy evaluation using lambdas."""
        if self.is_enabled_for(_CRITICAL):
            self._log(_CRITICAL, msg, *args, **kwargs)

critical(msg, *args, **kwargs)

Log a critical message with support for lazy evaluation using lambdas.

Source code in aiperf/common/aiperf_logger.py
202
203
204
205
def critical(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
    """Log a critical message with support for lazy evaluation using lambdas."""
    if self.is_enabled_for(_CRITICAL):
        self._log(_CRITICAL, msg, *args, **kwargs)

debug(msg, *args, **kwargs)

Log a debug message with support for lazy evaluation using lambdas.

Source code in aiperf/common/aiperf_logger.py
167
168
169
170
def debug(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
    """Log a debug message with support for lazy evaluation using lambdas."""
    if self.is_enabled_for(_DEBUG):
        self._log(_DEBUG, msg, *args, **kwargs)

error(msg, *args, **kwargs)

Log an error message with support for lazy evaluation using lambdas.

Source code in aiperf/common/aiperf_logger.py
192
193
194
195
def error(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
    """Log an error message with support for lazy evaluation using lambdas."""
    if self.is_enabled_for(_ERROR):
        self._log(_ERROR, msg, *args, **kwargs)

exception(msg, *args, **kwargs)

Log an exception message with support for lazy evaluation using lambdas.

Source code in aiperf/common/aiperf_logger.py
197
198
199
200
def exception(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
    """Log an exception message with support for lazy evaluation using lambdas."""
    if self.is_enabled_for(_ERROR):
        self._log(_ERROR, msg, *args, exc_info=True, **kwargs)

find_caller(stack_info=False, stacklevel=1)

NOTE: This is a modified version of the findCaller method in the logging module, in order to allow us to add custom ignored files.

Find the stack frame of the caller so that we can note the source file name, line number and function name.

Source code in aiperf/common/aiperf_logger.py
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
def find_caller(
    self, stack_info=False, stacklevel=1
) -> tuple[str, int, str, str | None]:
    """
    NOTE: This is a modified version of the findCaller method in the logging module,
    in order to allow us to add custom ignored files.

    Find the stack frame of the caller so that we can note the source
    file name, line number and function name.
    """
    f = currentframe()
    # On some versions of IronPython, currentframe() returns None if
    # IronPython isn't run with -X:Frames.
    if f is not None:
        f = f.f_back
    orig_f = f
    while f and stacklevel > 1:
        f = f.f_back
        stacklevel -= 1
    if not f:
        f = orig_f
    rv = "(unknown file)", 0, "(unknown function)", None
    while f and hasattr(f, "f_code"):
        co = f.f_code
        filename = os.path.normcase(co.co_filename)
        # NOTE: The if-statement below was modified to use our own list of ignored files (_ignored_files).
        # This is required to avoid it appearing as all logs are coming from this file.
        if filename in _ignored_files:
            f = f.f_back
            continue
        sinfo = None
        if stack_info:
            sio = io.StringIO()
            sio.write("Stack (most recent call last):\n")
            traceback.print_stack(f, file=sio)
            sinfo = sio.getvalue()
            if sinfo[-1] == "\n":
                sinfo = sinfo[:-1]
            sio.close()
        rv = (co.co_filename, f.f_lineno, co.co_name, sinfo)
        break
    return rv

get_level_number(level) classmethod

Get the numeric level for the given level.

Source code in aiperf/common/aiperf_logger.py
106
107
108
109
110
111
112
@classmethod
def get_level_number(cls, level: int | str) -> int:
    """Get the numeric level for the given level."""
    if isinstance(level, str):
        return getattr(cls, level.upper())
    else:
        return level

info(msg, *args, **kwargs)

Log an info message with support for lazy evaluation using lambdas.

Source code in aiperf/common/aiperf_logger.py
172
173
174
175
def info(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
    """Log an info message with support for lazy evaluation using lambdas."""
    if self.is_enabled_for(_INFO):
        self._log(_INFO, msg, *args, **kwargs)

is_valid_level(level) classmethod

Check if the given level is a valid level.

Source code in aiperf/common/aiperf_logger.py
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
@classmethod
def is_valid_level(cls, level: int | str) -> bool:
    """Check if the given level is a valid level."""
    if isinstance(level, str):
        return level in [
            "TRACE",
            "DEBUG",
            "INFO",
            "NOTICE",
            "WARNING",
            "SUCCESS",
            "ERROR",
            "CRITICAL",
        ]
    else:
        return level in [
            _TRACE,
            _DEBUG,
            _INFO,
            _NOTICE,
            _WARNING,
            _SUCCESS,
            _ERROR,
            _CRITICAL,
        ]

log(level, msg, *args, **kwargs)

Log a message with support for lazy evaluation using lambdas.

Source code in aiperf/common/aiperf_logger.py
157
158
159
160
def log(self, level: int, msg: str | Callable[..., str], *args, **kwargs) -> None:
    """Log a message with support for lazy evaluation using lambdas."""
    if self.is_enabled_for(level):
        self._log(level, msg, args, **kwargs)

notice(msg, *args, **kwargs)

Log a notice message with support for lazy evaluation using lambdas.

Source code in aiperf/common/aiperf_logger.py
177
178
179
180
def notice(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
    """Log a notice message with support for lazy evaluation using lambdas."""
    if self.is_enabled_for(_NOTICE):
        self._log(_NOTICE, msg, *args, **kwargs)

success(msg, *args, **kwargs)

Log a success message with support for lazy evaluation using lambdas.

Source code in aiperf/common/aiperf_logger.py
187
188
189
190
def success(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
    """Log a success message with support for lazy evaluation using lambdas."""
    if self.is_enabled_for(_SUCCESS):
        self._log(_SUCCESS, msg, *args, **kwargs)

trace(msg, *args, **kwargs)

Log a trace message with support for lazy evaluation using lambdas.

Source code in aiperf/common/aiperf_logger.py
162
163
164
165
def trace(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
    """Log a trace message with support for lazy evaluation using lambdas."""
    if self.is_enabled_for(_TRACE):
        self._log(_TRACE, msg, *args, **kwargs)

warning(msg, *args, **kwargs)

Log a warning message with support for lazy evaluation using lambdas.

Source code in aiperf/common/aiperf_logger.py
182
183
184
185
def warning(self, msg: str | Callable[..., str], *args, **kwargs) -> None:
    """Log a warning message with support for lazy evaluation using lambdas."""
    if self.is_enabled_for(_WARNING):
        self._log(_WARNING, msg, *args, **kwargs)

aiperf.common.bootstrap

bootstrap_and_run_service(service_class, service_config=None, user_config=None, service_id=None, log_queue=None, **kwargs)

Bootstrap the service and run it.

This function will load the service configuration, create an instance of the service, and run it.

Parameters:

Name Type Description Default
service_class type[BaseService]

The Python class of the service to run. It must be a subclass of BaseService, and must be supplied as the class itself (a type), not an instance.

required
service_config ServiceConfig | None

The service configuration to use. If not provided, the service configuration will be loaded from the environment variables.

None
user_config UserConfig | None

The user configuration to use. If not provided, the user configuration will be loaded from the environment variables.

None
service_id str | None

The service ID to pass through to the service constructor. Optional.

None
log_queue Queue | None

Optional multiprocessing queue for child process logging. If provided, the child process logging will be set up.

None
kwargs

Additional keyword arguments to pass to the service constructor.

{}
Source code in aiperf/common/bootstrap.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
def bootstrap_and_run_service(
    service_class: type[BaseService],
    service_config: ServiceConfig | None = None,
    user_config: UserConfig | None = None,
    service_id: str | None = None,
    log_queue: "multiprocessing.Queue | None" = None,
    **kwargs,
):
    """Bootstrap the service and run it.

    This function will load the service configuration,
    create an instance of the service, and run it.

    Args:
        service_class: The python class of the service to run. This should be a subclass of
            BaseService. This should be a type and not an instance.
        service_config: The service configuration to use. If not provided, the service
            configuration will be loaded from the environment variables.
        user_config: The user configuration to use. If not provided, the user configuration
            will be loaded from the environment variables.
        log_queue: Optional multiprocessing queue for child process logging. If provided,
            the child process logging will be set up.
        kwargs: Additional keyword arguments to pass to the service constructor.
    """

    # Load the service configuration
    if service_config is None:
        from aiperf.common.config import load_service_config

        service_config = load_service_config()

    # Load the user configuration
    if user_config is None:
        from aiperf.common.config import load_user_config

        # TODO: Add support for loading user config from a file/environment variables
        user_config = load_user_config()

    async def _run_service():
        if service_config.enable_yappi:
            _start_yappi_profiling()

        service = service_class(
            service_config=service_config,
            user_config=user_config,
            service_id=service_id,
            **kwargs,
        )

        from aiperf.common.logging import setup_child_process_logging

        setup_child_process_logging(
            log_queue, service.service_id, service_config, user_config
        )

        if user_config.input.random_seed is not None:
            random.seed(user_config.input.random_seed)
            # Try and set the numpy random seed
            # https://numpy.org/doc/stable/reference/random/index.html#random-quick-start
            with contextlib.suppress(ImportError):
                import numpy as np

                np.random.seed(user_config.input.random_seed)

        with contextlib.suppress(asyncio.CancelledError):
            await service.run_forever()

        if service_config.enable_yappi:
            _stop_yappi_profiling(service.service_id, user_config)

    with contextlib.suppress(asyncio.CancelledError):
        if service_config.enable_uvloop:
            import uvloop

            uvloop.run(_run_service())
        else:
            asyncio.run(_run_service())

aiperf.common.comms.base

BaseCommunication

Bases: ABC

Base class for specifying the base communication layer for AIPerf components.

Source code in aiperf/common/comms/base.py
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
class BaseCommunication(ABC):
    """Base class for specifying the base communication layer for AIPerf components."""

    @abstractmethod
    async def initialize(self) -> None:
        """Initialize communication channels."""

    @property
    @abstractmethod
    def is_initialized(self) -> bool:
        """Check if communication channels are initialized.

        Returns:
            True if communication channels are initialized, False otherwise
        """

    @property
    @abstractmethod
    def stop_requested(self) -> bool:
        """Check if the communication channels are being shutdown.

        Returns:
            True if the communication channels are being shutdown, False otherwise
        """

    @abstractmethod
    async def shutdown(self) -> None:
        """Gracefully shutdown communication channels."""

    @abstractmethod
    def get_address(self, address_type: CommunicationClientAddressType | str) -> str:
        """Get the address for a given address type.

        Args:
            address_type: The type of address to get the address for, or the address itself.

        Returns:
            The address for the given address type, or the address itself if it is a string.
        """

    @abstractmethod
    def create_client(
        self,
        client_type: CommunicationClientType,
        address: CommunicationClientAddressType | str,
        bind: bool = False,
        socket_ops: dict | None = None,
    ) -> CommunicationClientProtocol:
        """Create a communication client for a given client type and address.

        Args:
            client_type: The type of client to create.
            address: The type of address to use when looking up in the communication config, or the address itself.
            bind: Whether to bind or connect the socket.
            socket_ops: Additional socket options to set.
        """

    create_pub_client = _create_specific_client(
        CommunicationClientType.PUB, PubClientProtocol
    )
    create_sub_client = _create_specific_client(
        CommunicationClientType.SUB, SubClientProtocol
    )
    create_push_client = _create_specific_client(
        CommunicationClientType.PUSH, PushClientProtocol
    )
    create_pull_client = _create_specific_client(
        CommunicationClientType.PULL, PullClientProtocol
    )
    create_request_client = _create_specific_client(
        CommunicationClientType.REQUEST, RequestClientProtocol
    )
    create_reply_client = _create_specific_client(
        CommunicationClientType.REPLY, ReplyClientProtocol
    )

is_initialized abstractmethod property

Check if communication channels are initialized.

Returns:

Type Description
bool

True if communication channels are initialized, False otherwise

stop_requested abstractmethod property

Check if the communication channels are being shutdown.

Returns:

Type Description
bool

True if the communication channels are being shutdown, False otherwise

create_client(client_type, address, bind=False, socket_ops=None) abstractmethod

Create a communication client for a given client type and address.

Parameters:

Name Type Description Default
client_type CommunicationClientType

The type of client to create.

required
address CommunicationClientAddressType | str

The type of address to use when looking up in the communication config, or the address itself.

required
bind bool

Whether to bind or connect the socket.

False
socket_ops dict | None

Additional socket options to set.

None
Source code in aiperf/common/comms/base.py
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
@abstractmethod
def create_client(
    self,
    client_type: CommunicationClientType,
    address: CommunicationClientAddressType | str,
    bind: bool = False,
    socket_ops: dict | None = None,
) -> CommunicationClientProtocol:
    """Create a communication client for a given client type and address.

    Args:
        client_type: The type of client to create.
        address: The type of address to use when looking up in the communication config, or the address itself.
        bind: Whether to bind or connect the socket.
        socket_ops: Additional socket options to set.
    """

get_address(address_type) abstractmethod

Get the address for a given address type.

Parameters:

Name Type Description Default
address_type CommunicationClientAddressType | str

The type of address to get the address for, or the address itself.

required

Returns:

Type Description
str

The address for the given address type, or the address itself if it is a string.

Source code in aiperf/common/comms/base.py
247
248
249
250
251
252
253
254
255
256
@abstractmethod
def get_address(self, address_type: CommunicationClientAddressType | str) -> str:
    """Get the address for a given address type.

    Args:
        address_type: The type of address to get the address for, or the address itself.

    Returns:
        The address for the given address type, or the address itself if it is a string.
    """

initialize() abstractmethod async

Initialize communication channels.

Source code in aiperf/common/comms/base.py
221
222
223
@abstractmethod
async def initialize(self) -> None:
    """Initialize communication channels."""

shutdown() abstractmethod async

Gracefully shutdown communication channels.

Source code in aiperf/common/comms/base.py
243
244
245
@abstractmethod
async def shutdown(self) -> None:
    """Gracefully shutdown communication channels."""

CommunicationClientFactory

Bases: FactoryMixin[CommunicationClientType, CommunicationClientProtocol]

Factory for registering and creating BaseCommunicationClient instances based on the specified client type.

Example:

    # Register a new client type
    @CommunicationClientFactory.register(ClientType.PUB)
    class ZMQPubClient(BaseZMQClient):
        pass

    # Create a new client instance
    client = CommunicationClientFactory.create_instance(
        ClientType.PUB,
        address=ClientAddressType.SERVICE_XSUB_FRONTEND,
        bind=False,
    )
Source code in aiperf/common/comms/base.py
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
class CommunicationClientFactory(
    FactoryMixin[CommunicationClientType, CommunicationClientProtocol]
):
    """Factory for registering and creating BaseCommunicationClient instances based on the specified client type.

    Example:
    ```python
        # Register a new client type
        @CommunicationClientFactory.register(ClientType.PUB)
        class ZMQPubClient(BaseZMQClient):
            pass

        # Create a new client instance
        client = CommunicationClientFactory.create_instance(
            ClientType.PUB,
            address=ClientAddressType.SERVICE_XSUB_FRONTEND,
            bind=False,
        )
    ```
    """

CommunicationClientProtocol

Bases: Protocol

Base interface for specifying the base communication client for AIPerf components.

Source code in aiperf/common/comms/base.py
22
23
24
25
26
27
28
29
30
31
32
@runtime_checkable
class CommunicationClientProtocol(Protocol):
    """Base interface for specifying the base communication client for AIPerf components."""

    async def initialize(self) -> None:
        """Initialize communication channels."""
        ...

    async def shutdown(self) -> None:
        """Shutdown communication channels."""
        ...

initialize() async

Initialize communication channels.

Source code in aiperf/common/comms/base.py
26
27
28
async def initialize(self) -> None:
    """Initialize communication channels."""
    ...

shutdown() async

Shutdown communication channels.

Source code in aiperf/common/comms/base.py
30
31
32
async def shutdown(self) -> None:
    """Shutdown communication channels."""
    ...

CommunicationClientProtocolFactory

Bases: FactoryMixin[CommunicationClientType, CommunicationClientProtocol]

Factory for registering CommunicationClientProtocol interfaces for dynamic client creation.

Source code in aiperf/common/comms/base.py
35
36
37
38
class CommunicationClientProtocolFactory(
    FactoryMixin[CommunicationClientType, CommunicationClientProtocol]
):
    """Factory for registering CommunicationClientProtocol interfaces for dynamic client creation."""

CommunicationFactory

Bases: FactoryMixin[CommunicationBackend, BaseCommunication]

Factory for registering and creating BaseCommunication instances based on the specified communication backend. See :class:FactoryMixin for more details.

Source code in aiperf/common/comms/base.py
295
296
297
298
class CommunicationFactory(FactoryMixin[CommunicationBackend, BaseCommunication]):
    """Factory for registering and creating BaseCommunication instances based on the specified communication backend.
    See :class:`FactoryMixin` for more details.
    """

PubClientProtocol

Bases: CommunicationClientProtocol

Interface for publish clients.

Source code in aiperf/common/comms/base.py
147
148
149
150
151
152
153
@CommunicationClientProtocolFactory.register(CommunicationClientType.PUB)
class PubClientProtocol(CommunicationClientProtocol):
    """Interface for publish clients."""

    async def publish(self, message: MessageT) -> None:
        """Publish a message. The message will be routed automatically based on the message type."""
        ...

publish(message) async

Publish a message. The message will be routed automatically based on the message type.

Source code in aiperf/common/comms/base.py
151
152
153
async def publish(self, message: MessageT) -> None:
    """Publish a message. The message will be routed automatically based on the message type."""
    ...

PullClientProtocol

Bases: CommunicationClientProtocol

Interface for pull clients.

Source code in aiperf/common/comms/base.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
@CommunicationClientProtocolFactory.register(CommunicationClientType.PULL)
class PullClientProtocol(CommunicationClientProtocol):
    """Interface for pull clients."""

    async def register_pull_callback(
        self,
        message_type: MessageTypeT,
        callback: Callable[[MessageT], Coroutine[Any, Any, None]],
        max_concurrency: int | None = None,
    ) -> None:
        """Register a callback for a pull client. The callback will be called when
        a message is received for the given message type.

        Args:
            message_type: The message type to register the callback for
            callback: The callback to register
            max_concurrency: The maximum number of concurrent requests to allow.
        """
        ...

register_pull_callback(message_type, callback, max_concurrency=None) async

Register a callback for a pull client. The callback will be called when a message is received for the given message type.

Parameters:

Name Type Description Default
message_type MessageTypeT

The message type to register the callback for

required
callback Callable[[MessageT], Coroutine[Any, Any, None]]

The callback to register

required
max_concurrency int | None

The maximum number of concurrent requests to allow.

None
Source code in aiperf/common/comms/base.py
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
async def register_pull_callback(
    self,
    message_type: MessageTypeT,
    callback: Callable[[MessageT], Coroutine[Any, Any, None]],
    max_concurrency: int | None = None,
) -> None:
    """Register a callback for a pull client. The callback will be called when
    a message is received for the given message type.

    Args:
        message_type: The message type to register the callback for
        callback: The callback to register
        max_concurrency: The maximum number of concurrent requests to allow.
    """
    ...

PushClientProtocol

Bases: CommunicationClientProtocol

Interface for push clients.

Source code in aiperf/common/comms/base.py
41
42
43
44
45
46
47
48
49
50
51
52
@CommunicationClientProtocolFactory.register(CommunicationClientType.PUSH)
class PushClientProtocol(CommunicationClientProtocol):
    """Interface for push clients."""

    async def push(self, message: Message) -> None:
        """Push data to a target. The message will be routed automatically
        based on the message.message_type.

        Args:
            message: Message to be pushed
        """
        ...

push(message) async

Push data to a target. The message will be routed automatically based on the message.message_type.

Parameters:

Name Type Description Default
message Message

Message to be pushed

required
Source code in aiperf/common/comms/base.py
45
46
47
48
49
50
51
52
async def push(self, message: Message) -> None:
    """Push data to a target. The message will be routed automatically
    based on the message.message_type.

    Args:
        message: Message to be pushed
    """
    ...

ReplyClientProtocol

Bases: CommunicationClientProtocol

Interface for reply clients.

Source code in aiperf/common/comms/base.py
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
@CommunicationClientProtocolFactory.register(CommunicationClientType.REPLY)
class ReplyClientProtocol(CommunicationClientProtocol):
    """Interface for reply clients."""

    def register_request_handler(
        self,
        service_id: str,
        message_type: MessageTypeT,
        handler: Callable[[MessageT], Coroutine[Any, Any, MessageOutputT | None]],
    ) -> None:
        """Register a request handler for a message type. The handler will be called when
        a request is received for the given message type.

        Args:
            service_id: The service ID to register the handler for
            message_type: The message type to register the handler for
            handler: The handler to register
        """
        ...

register_request_handler(service_id, message_type, handler)

Register a request handler for a message type. The handler will be called when a request is received for the given message type.

Parameters:

Name Type Description Default
service_id str

The service ID to register the handler for

required
message_type MessageTypeT

The message type to register the handler for

required
handler Callable[[MessageT], Coroutine[Any, Any, MessageOutputT | None]]

The handler to register

required
Source code in aiperf/common/comms/base.py
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
def register_request_handler(
    self,
    service_id: str,
    message_type: MessageTypeT,
    handler: Callable[[MessageT], Coroutine[Any, Any, MessageOutputT | None]],
) -> None:
    """Register a request handler for a message type. The handler will be called when
    a request is received for the given message type.

    Args:
        service_id: The service ID to register the handler for
        message_type: The message type to register the handler for
        handler: The handler to register
    """
    ...

RequestClientProtocol

Bases: CommunicationClientProtocol

Interface for request clients.

Source code in aiperf/common/comms/base.py
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
@CommunicationClientProtocolFactory.register(CommunicationClientType.REQUEST)
class RequestClientProtocol(CommunicationClientProtocol):
    """Interface for request clients."""

    async def request(
        self,
        message: MessageT,  # type: ignore[type-arg]
        timeout: float = DEFAULT_COMMS_REQUEST_TIMEOUT,
    ) -> MessageOutputT:  # type: ignore[type-arg]
        """Send a request and wait for a response. The message will be routed automatically
        based on the message type.

        Args:
            message: Message to send (will be routed based on the message type)
            timeout: Timeout in seconds for the request.

        Returns:
            Response message if successful
        """
        ...

    async def request_async(
        self,
        message: MessageT,
        callback: Callable[[MessageOutputT], Coroutine[Any, Any, None]],
    ) -> None:
        """Send a request and be notified when the response is received. The message will be routed automatically
        based on the message type.

        Args:
            message: Message to send (will be routed based on the message type)
            callback: Callback to be called when the response is received
        """
        ...

request(message, timeout=DEFAULT_COMMS_REQUEST_TIMEOUT) async

Send a request and wait for a response. The message will be routed automatically based on the message type.

Parameters:

Name Type Description Default
message MessageT

Message to send (will be routed based on the message type)

required
timeout float

Timeout in seconds for the request.

DEFAULT_COMMS_REQUEST_TIMEOUT

Returns:

Type Description
MessageOutputT

Response message if successful

Source code in aiperf/common/comms/base.py
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
async def request(
    self,
    message: MessageT,  # type: ignore[type-arg]
    timeout: float = DEFAULT_COMMS_REQUEST_TIMEOUT,
) -> MessageOutputT:  # type: ignore[type-arg]
    """Send a request and wait for a response. The message will be routed automatically
    based on the message type.

    Args:
        message: Message to send (will be routed based on the message type)
        timeout: Timeout in seconds for the request.

    Returns:
        Response message if successful
    """
    ...

request_async(message, callback) async

Send a request and be notified when the response is received. The message will be routed automatically based on the message type.

Parameters:

Name Type Description Default
message MessageT

Message to send (will be routed based on the message type)

required
callback Callable[[MessageOutputT], Coroutine[Any, Any, None]]

Callback to be called when the response is received

required
Source code in aiperf/common/comms/base.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
async def request_async(
    self,
    message: MessageT,
    callback: Callable[[MessageOutputT], Coroutine[Any, Any, None]],
) -> None:
    """Send a request and be notified when the response is received. The message will be routed automatically
    based on the message type.

    Args:
        message: Message to send (will be routed based on the message type)
        callback: Callback to be called when the response is received
    """
    ...

SubClientProtocol

Bases: CommunicationClientProtocol

Interface for subscribe clients.

Source code in aiperf/common/comms/base.py
133
134
135
136
137
138
139
140
141
142
143
144
@CommunicationClientProtocolFactory.register(CommunicationClientType.SUB)
class SubClientProtocol(CommunicationClientProtocol):
    """Interface for subscribe clients."""

    async def subscribe(
        self,
        message_type: MessageTypeT,
        callback: Callable[[MessageT], Coroutine[Any, Any, None]],
    ) -> None:
        """Subscribe to a specific message type. The callback will be called when
        a message is received for the given message type."""
        ...

subscribe(message_type, callback) async

Subscribe to a specific message type. The callback will be called when a message is received for the given message type.

Source code in aiperf/common/comms/base.py
137
138
139
140
141
142
143
144
async def subscribe(
    self,
    message_type: MessageTypeT,
    callback: Callable[[MessageT], Coroutine[Any, Any, None]],
) -> None:
    """Subscribe to a specific message type. The callback will be called when
    a message is received for the given message type.

    Args:
        message_type: The message type to subscribe to.
        callback: Coroutine invoked with each received message of that type.
    """
    ...  # Protocol stub: concrete SUB clients provide the implementation.

aiperf.common.comms.zmq.dealer_request_client

ZMQDealerRequestClient

Bases: BaseZMQClient, AsyncTaskManagerMixin

ZMQ DEALER socket client for asynchronous request-response communication.

The DEALER socket connects to ROUTER sockets and can send requests asynchronously, receiving responses through callbacks or awaitable futures.

ASCII Diagram: ┌──────────────┐ ┌──────────────┐ │ DEALER │───── Request ─────>│ ROUTER │ │ (Client) │ │ (Service) │ │ │<─── Response ──────│ │ └──────────────┘ └──────────────┘

Usage Pattern: - DEALER Clients send requests to ROUTER Services - Responses are routed back to the originating DEALER

DEALER/ROUTER is a Many-to-One communication pattern. If you need Many-to-Many, use a ZMQ Proxy as well. See :class:`ZMQDealerRouterProxy` for more details.

Source code in aiperf/common/comms/zmq/dealer_request_client.py
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
@CommunicationClientFactory.register(CommunicationClientType.REQUEST)
class ZMQDealerRequestClient(BaseZMQClient, AsyncTaskManagerMixin):
    """
    ZMQ DEALER socket client for asynchronous request-response communication.

    The DEALER socket connects to ROUTER sockets and can send requests asynchronously,
    receiving responses through callbacks or awaitable futures.

    ASCII Diagram:
    ┌──────────────┐                    ┌──────────────┐
    │    DEALER    │───── Request ─────>│    ROUTER    │
    │   (Client)   │                    │  (Service)   │
    │              │<─── Response ──────│              │
    └──────────────┘                    └──────────────┘

    Usage Pattern:
    - DEALER Clients send requests to ROUTER Services
    - Responses are routed back to the originating DEALER

    DEALER/ROUTER is a Many-to-One communication pattern. If you need Many-to-Many,
    use a ZMQ Proxy as well. see :class:`ZMQDealerRouterProxy` for more details.
    """

    def __init__(
        self,
        context: zmq.asyncio.Context,
        address: str,
        bind: bool,
        socket_ops: dict | None = None,
    ) -> None:
        """
        Initialize the ZMQ Dealer (Req) client class.

        Args:
            context (zmq.asyncio.Context): The ZMQ context.
            address (str): The address to bind or connect to.
            bind (bool): Whether to bind or connect the socket.
            socket_ops (dict, optional): Additional socket options to set.
        """
        super().__init__(context, zmq.SocketType.DEALER, address, bind, socket_ops)

        # Pending response callbacks keyed by request_id. Entries are removed
        # when the matching response arrives, or when a request() times out.
        self.request_callbacks: dict[
            str, Callable[[Message], Coroutine[Any, Any, None]]
        ] = {}

    @aiperf_task
    async def _request_async_task(self) -> None:
        """Background task that receives responses from the socket and
        dispatches each to the callback registered for its request_id."""
        while not self.stop_event.is_set():
            try:
                message = await self._socket.recv_string()
                self.trace(lambda msg=message: f"Received response: {msg}")
                response_message = Message.from_json(message)

                # Call the callback if it exists; pop so each request_id
                # fires at most once.
                if response_message.request_id in self.request_callbacks:
                    callback = self.request_callbacks.pop(response_message.request_id)
                    self.execute_async(callback(response_message))

            except zmq.Again:
                self.trace(lambda: "No data received, yielding to event loop")
                await yield_to_event_loop()
                continue

            except (asyncio.CancelledError, zmq.ContextTerminated):
                raise  # re-raise the cancelled error

            except Exception as e:
                self.exception(f"Exception receiving responses: {e}")
                await yield_to_event_loop()
                continue

    @on_stop
    async def _stop_remaining_tasks(self) -> None:
        """Wait for all tasks to complete."""
        await self.cancel_all_tasks()

    async def request_async(
        self,
        message: Message,
        callback: Callable[[Message], Coroutine[Any, Any, None]],
    ) -> None:
        """Send a request and be notified when the response is received.

        Args:
            message: Message to send (will be routed based on the message type)
            callback: Callback to be called when the response is received

        Raises:
            TypeError: If ``message`` is not a ``Message`` instance.
            CommunicationError: If sending the request fails.
        """
        await self._ensure_initialized()

        if not isinstance(message, Message):
            raise TypeError(
                f"message must be an instance of Message, got {type(message).__name__}"
            )

        # Generate request ID if not provided so that responses can be matched
        if not message.request_id:
            message.request_id = str(uuid.uuid4())

        self.request_callbacks[message.request_id] = callback

        request_json = message.model_dump_json()
        self.trace(lambda msg=request_json: f"Sending request: {msg}")

        try:
            await self._socket.send_string(request_json)

        except Exception as e:
            raise CommunicationError(
                f"Exception sending request: {e.__class__.__qualname__} {e}",
            ) from e

    async def request(
        self,
        message: Message,
        timeout: float = DEFAULT_COMMS_REQUEST_TIMEOUT,
    ) -> Message:
        """Send a request and wait for a response up to timeout seconds.

        Args:
            message (Message): The request message to send.
            timeout (float): Maximum time to wait for a response in seconds.

        Returns:
            Message: The response message received.

        Raises:
            CommunicationError: if the request fails, or
            asyncio.TimeoutError: if the response is not received in time.
        """
        future = asyncio.Future[Message]()

        async def callback(response_message: Message) -> None:
            # BUGFIX: a response may arrive after wait_for() below has timed
            # out and cancelled the future; calling set_result on a cancelled
            # future would raise InvalidStateError in the dispatch task.
            if not future.done():
                future.set_result(response_message)

        await self.request_async(message, callback)
        try:
            return await asyncio.wait_for(future, timeout=timeout)
        except asyncio.TimeoutError:
            # BUGFIX: drop the stale callback so timed-out requests do not
            # leak entries in request_callbacks indefinitely.
            self.request_callbacks.pop(message.request_id, None)
            raise

__init__(context, address, bind, socket_ops=None)

Initialize the ZMQ Dealer (Req) client class.

Parameters:

Name Type Description Default
context Context

The ZMQ context.

required
address str

The address to bind or connect to.

required
bind bool

Whether to bind or connect the socket.

required
socket_ops dict

Additional socket options to set.

None
Source code in aiperf/common/comms/zmq/dealer_request_client.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def __init__(
    self,
    context: zmq.asyncio.Context,
    address: str,
    bind: bool,
    socket_ops: dict | None = None,
) -> None:
    """Set up the DEALER request client.

    Args:
        context (zmq.asyncio.Context): The ZMQ context.
        address (str): The address to bind or connect to.
        bind (bool): Whether to bind or connect the socket.
        socket_ops (dict, optional): Additional socket options to set.
    """
    super().__init__(context, zmq.SocketType.DEALER, address, bind, socket_ops)

    # Pending response callbacks, keyed by request_id; the receive task
    # invokes and removes an entry when its matching response arrives.
    self.request_callbacks: dict[
        str, Callable[[Message], Coroutine[Any, Any, None]]
    ] = {}

request(message, timeout=DEFAULT_COMMS_REQUEST_TIMEOUT) async

Send a request and wait for a response up to timeout seconds.

Parameters:

Name Type Description Default
message Message

The request message to send.

required
timeout float

Maximum time to wait for a response in seconds.

DEFAULT_COMMS_REQUEST_TIMEOUT

Returns:

Name Type Description
Message Message

The response message received.

Raises:

Type Description
CommunicationError

if the request fails, or

TimeoutError

if the response is not received in time.

Source code in aiperf/common/comms/zmq/dealer_request_client.py
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
async def request(
    self,
    message: Message,
    timeout: float = DEFAULT_COMMS_REQUEST_TIMEOUT,
) -> Message:
    """Send a request and wait for a response up to timeout seconds.

    Args:
        message (Message): The request message to send.
        timeout (float): Maximum time to wait for a response in seconds.

    Returns:
        Message: The response message received.

    Raises:
        CommunicationError: if the request fails, or
        asyncio.TimeoutError: if the response is not received in time.
    """
    future = asyncio.Future[Message]()

    async def callback(response_message: Message) -> None:
        # BUGFIX: a response may arrive after wait_for() below has timed
        # out and cancelled the future; calling set_result on a cancelled
        # future would raise InvalidStateError in the dispatch task.
        if not future.done():
            future.set_result(response_message)

    await self.request_async(message, callback)
    try:
        return await asyncio.wait_for(future, timeout=timeout)
    except asyncio.TimeoutError:
        # BUGFIX: drop the stale callback so timed-out requests do not
        # leak entries in request_callbacks indefinitely.
        self.request_callbacks.pop(message.request_id, None)
        raise

request_async(message, callback) async

Send a request and be notified when the response is received.

Source code in aiperf/common/comms/zmq/dealer_request_client.py
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
async def request_async(
    self,
    message: Message,
    callback: Callable[[Message], Coroutine[Any, Any, None]],
) -> None:
    """Dispatch a request and register a callback for its response.

    The message is routed automatically based on the message type, and the
    callback fires once a response with the matching request_id arrives.
    """
    await self._ensure_initialized()

    if not isinstance(message, Message):
        raise TypeError(
            f"message must be an instance of Message, got {type(message).__name__}"
        )

    # A request_id is required to pair responses with callbacks; mint one
    # if the caller did not supply it.
    if not message.request_id:
        message.request_id = str(uuid.uuid4())
    self.request_callbacks[message.request_id] = callback

    payload = message.model_dump_json()
    self.trace(lambda msg=payload: f"Sending request: {msg}")
    try:
        await self._socket.send_string(payload)
    except Exception as e:
        raise CommunicationError(
            f"Exception sending request: {e.__class__.__qualname__} {e}",
        ) from e

aiperf.common.comms.zmq.pub_client

ZMQPubClient

Bases: BaseZMQClient

The PUB socket broadcasts messages to all connected SUB sockets that have subscribed to the message topic/type.

ASCII Diagram: ┌──────────────┐ ┌──────────────┐ │ PUB │───>│ │ │ (Publisher) │ │ │ └──────────────┘ │ SUB │ ┌──────────────┐ │ (Subscriber) │ │ PUB │───>│ │ │ (Publisher) │ │ │ └──────────────┘ └──────────────┘ OR ┌──────────────┐ ┌──────────────┐ │ │───>│ SUB │ │ │ │ (Subscriber) │ │ PUB │ └──────────────┘ │ (Publisher) │ ┌──────────────┐ │ │───>│ SUB │ │ │ │ (Subscriber) │ └──────────────┘ └──────────────┘

Usage Pattern: - Single PUB socket broadcasts messages to all subscribers (One-to-Many) OR - Multiple PUB sockets broadcast messages to a single SUB socket (Many-to-One)

  - SUB sockets filter messages by topic/message_type
  - Fire-and-forget messaging (no acknowledgments)

PUB/SUB is a One-to-Many communication pattern. If you need Many-to-Many, use a ZMQ Proxy as well. See :class:`ZMQXPubXSubProxy` for more details.

Source code in aiperf/common/comms/zmq/pub_client.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
@CommunicationClientFactory.register(CommunicationClientType.PUB)
class ZMQPubClient(BaseZMQClient):
    """ZMQ PUB socket client for topic-based broadcast messaging.

    Every published message is delivered to all connected SUB sockets that
    subscribed to its topic (the message_type). Either one publisher fans
    out to many subscribers (One-to-Many), or several publishers feed a
    single subscriber (Many-to-One):

    ┌──────────────┐    ┌──────────────┐
    │     PUB      │───>│              │
    │ (Publisher)  │    │              │
    └──────────────┘    │     SUB      │
    ┌──────────────┐    │ (Subscriber) │
    │     PUB      │───>│              │
    │ (Publisher)  │    │              │
    └──────────────┘    └──────────────┘
    OR
    ┌──────────────┐    ┌──────────────┐
    │              │───>│     SUB      │
    │              │    │ (Subscriber) │
    │     PUB      │    └──────────────┘
    │ (Publisher)  │    ┌──────────────┐
    │              │───>│     SUB      │
    │              │    │ (Subscriber) │
    └──────────────┘    └──────────────┘

    Delivery is fire-and-forget: there are no acknowledgments, and SUB
    sockets filter messages by topic/message_type on their side.

    PUB/SUB alone is not Many-to-Many; combine it with a ZMQ Proxy
    (see :class:`ZMQXPubXSubProxy`) when that topology is needed.
    """

    def __init__(
        self,
        context: zmq.asyncio.Context,
        address: str,
        bind: bool,
        socket_ops: dict | None = None,
    ) -> None:
        """Create the publisher on a PUB socket.

        Args:
            context (zmq.asyncio.Context): The ZMQ context.
            address (str): The address to bind or connect to.
            bind (bool): Whether to bind or connect the socket.
            socket_ops (dict, optional): Additional socket options to set.
        """
        super().__init__(context, zmq.SocketType.PUB, address, bind, socket_ops)

    async def publish(self, message: Message) -> None:
        """Broadcast a message; the topic is derived from the message type.

        Args:
            message: Message to publish (must be a Message object)

        Raises:
            CommunicationError: If serialization or the send fails.
        """
        await self._ensure_initialized()

        try:
            payload = message.model_dump_json()
            # Frame 0 carries the topic (used for SUB-side filtering),
            # frame 1 carries the JSON body.
            await self.socket.send_multipart(
                [message.message_type.encode(), payload.encode()]
            )
        except (asyncio.CancelledError, zmq.ContextTerminated):
            # Best-effort shutdown path: publishing during teardown is
            # silently abandoned rather than surfaced as an error.
            self.trace(
                lambda: f"Pub client {self.client_id} cancelled or context terminated"
            )
            return
        except Exception as e:
            raise CommunicationError(
                f"Failed to publish message {message.message_type}: {e}",
            ) from e

__init__(context, address, bind, socket_ops=None)

Initialize the ZMQ Publisher client class.

Parameters:

Name Type Description Default
context Context

The ZMQ context.

required
address str

The address to bind or connect to.

required
bind bool

Whether to bind or connect the socket.

required
socket_ops dict

Additional socket options to set.

None
Source code in aiperf/common/comms/zmq/pub_client.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def __init__(
    self,
    context: zmq.asyncio.Context,
    address: str,
    bind: bool,
    socket_ops: dict | None = None,
) -> None:
    """
    Initialize the ZMQ Publisher client class.

    Args:
        context (zmq.asyncio.Context): The ZMQ context.
        address (str): The address to bind or connect to.
        bind (bool): Whether to bind or connect the socket.
        socket_ops (dict, optional): Additional socket options to set.
    """
    # All socket setup is handled by the base class; this client only
    # fixes the socket type to PUB.
    super().__init__(context, zmq.SocketType.PUB, address, bind, socket_ops)

publish(message) async

Publish a message. The topic will be set automatically based on the message type.

Parameters:

Name Type Description Default
message Message

Message to publish (must be a Message object)

required
Source code in aiperf/common/comms/zmq/pub_client.py
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
async def publish(self, message: Message) -> None:
    """Broadcast a message; the topic is derived from the message type.

    Args:
        message: Message to publish (must be a Message object)

    Raises:
        CommunicationError: If serialization or the send fails.
    """
    await self._ensure_initialized()

    try:
        payload = message.model_dump_json()
        # Frame 0 carries the topic (used for SUB-side filtering),
        # frame 1 carries the JSON body.
        await self.socket.send_multipart(
            [message.message_type.encode(), payload.encode()]
        )
    except (asyncio.CancelledError, zmq.ContextTerminated):
        # Best-effort shutdown path: publishing during teardown is
        # silently abandoned rather than surfaced as an error.
        self.trace(
            lambda: f"Pub client {self.client_id} cancelled or context terminated"
        )
        return
    except Exception as e:
        raise CommunicationError(
            f"Failed to publish message {message.message_type}: {e}",
        ) from e

aiperf.common.comms.zmq.pull_client

ZMQPullClient

Bases: BaseZMQClient, AsyncTaskManagerMixin

ZMQ PULL socket client for receiving work from PUSH sockets.

The PULL socket receives messages from PUSH sockets in a pipeline pattern, distributing work fairly among multiple PULL workers.

ASCII Diagram: ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ PUSH │ │ PULL │ │ PULL │ │ (Producer) │ │ (Worker 1) │ │ (Worker 2) │ │ │ └─────────────┘ └─────────────┘ │ Tasks: │ ▲ ▲ │ - Task A │─────────────┘ │ │ - Task B │───────────────────────────────────┘ │ - Task C │─────────────┐ │ - Task D │ ▼ └─────────────┘ ┌─────────────┐ │ PULL │ │ (Worker N) │ └─────────────┘

Usage Pattern: - PULL receives work from multiple PUSH producers - Work is fairly distributed among PULL workers - Pipeline pattern for distributed processing - Each message is delivered to exactly one PULL socket

PULL/PUSH is a One-to-Many communication pattern. If you need Many-to-Many, use a ZMQ Proxy as well. See :class:`ZMQPushPullProxy` for more details.

Source code in aiperf/common/comms/zmq/pull_client.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
@CommunicationClientFactory.register(CommunicationClientType.PULL)
class ZMQPullClient(BaseZMQClient, AsyncTaskManagerMixin):
    """
    ZMQ PULL socket client for receiving work from PUSH sockets.

    The PULL socket receives messages from PUSH sockets in a pipeline pattern,
    distributing work fairly among multiple PULL workers.

    ASCII Diagram:
    ┌─────────────┐      ┌─────────────┐      ┌─────────────┐
    │    PUSH     │      │    PULL     │      │    PULL     │
    │ (Producer)  │      │ (Worker 1)  │      │ (Worker 2)  │
    │             │      └─────────────┘      └─────────────┘
    │   Tasks:    │             ▲                     ▲
    │   - Task A  │─────────────┘                     │
    │   - Task B  │───────────────────────────────────┘
    │   - Task C  │─────────────┐
    │   - Task D  │             ▼
    └─────────────┘      ┌─────────────┐
                         │    PULL     │
                         │ (Worker N)  │
                         └─────────────┘

    Usage Pattern:
    - PULL receives work from multiple PUSH producers
    - Work is fairly distributed among PULL workers
    - Pipeline pattern for distributed processing
    - Each message is delivered to exactly one PULL socket

    PULL/PUSH is a One-to-Many communication pattern. If you need Many-to-Many,
    use a ZMQ Proxy as well. see :class:`ZMQPushPullProxy` for more details.
    """

    def __init__(
        self,
        context: zmq.asyncio.Context,
        address: str,
        bind: bool,
        socket_ops: dict | None = None,
        max_concurrency: int | None = None,
    ) -> None:
        """
        Initialize the ZMQ Puller class.

        Args:
            context (zmq.asyncio.Context): The ZMQ context.
            address (str): The address to bind or connect to.
            bind (bool): Whether to bind or connect the socket.
            socket_ops (dict, optional): Additional socket options to set.
            max_concurrency (int, optional): The maximum number of concurrent requests to allow.
        """
        super().__init__(context, zmq.SocketType.PULL, address, bind, socket_ops)
        # One callback per message type, registered via register_pull_callback().
        self._pull_callbacks: dict[
            MessageTypeT, Callable[[Message], Coroutine[Any, Any, None]]
        ] = {}

        # Concurrency limit: explicit argument wins; otherwise fall back to
        # the AIPERF_WORKER_CONCURRENT_REQUESTS env var (default 500).
        if max_concurrency is not None:
            self.semaphore = asyncio.Semaphore(value=max_concurrency)
        else:
            self.semaphore = asyncio.Semaphore(
                value=int(os.getenv("AIPERF_WORKER_CONCURRENT_REQUESTS", 500))
            )

    @aiperf_task
    async def _pull_receiver(self) -> None:
        """Background task for receiving data from the pull socket.

        This method is a coroutine that will run indefinitely until the client is
        shutdown. It will wait for messages from the socket and handle them.
        """
        if not self.is_initialized:
            await self.initialized_event.wait()

        while not self.stop_event.is_set():
            handed_off = False
            try:
                # acquire the semaphore to limit the number of concurrent requests
                # NOTE: This MUST be done BEFORE calling recv_string() to allow the zmq push/pull
                # logic to properly load balance the requests.
                await self.semaphore.acquire()

                message_json = await self.socket.recv_string()
                self.trace(
                    lambda msg=message_json: f"Received message from pull socket: {msg}"
                )
                self.execute_async(self._process_message(message_json))
                # _process_message now owns the permit and releases it in its
                # finally block.
                handed_off = True

            except zmq.Again:
                self.semaphore.release()  # release the semaphore as it was not used
                await yield_to_event_loop()
                continue

            except (asyncio.CancelledError, zmq.ContextTerminated):
                self.semaphore.release()  # release the semaphore as it was not used
                break

            except Exception as e:
                self.exception(f"Exception receiving data from pull socket: {e}")
                # BUGFIX: the permit was acquired but never handed off to
                # _process_message, so it must be released here; otherwise
                # repeated receive errors would exhaust the semaphore and
                # deadlock the receiver.
                if not handed_off:
                    self.semaphore.release()
                # Sleep for a short time to allow the system to potentially recover
                # if there are temporary issues.
                await asyncio.sleep(0.1)

    @on_stop
    async def _stop(self) -> None:
        """Wait for all tasks to complete."""
        await self.cancel_all_tasks()

    async def _process_message(self, message_json: str) -> None:
        """Process a message from the pull socket.

        This method is called by the background task when a message is received from
        the pull socket. It will deserialize the message and call the appropriate
        callback function.
        """
        try:
            message = Message.from_json(message_json)

            # Call callbacks with Message object
            if message.message_type in self._pull_callbacks:
                await self._pull_callbacks[message.message_type](message)
            else:
                self.warning(
                    lambda message_type=message.message_type: f"Pull message received for message type {message_type} without callback"
                )
        finally:
            # always release the semaphore to allow receiving more messages
            self.semaphore.release()

    async def register_pull_callback(
        self,
        message_type: MessageTypeT,
        callback: Callable[[Message], Coroutine[Any, Any, None]],
        max_concurrency: int | None = None,
    ) -> None:
        """Register a ZMQ Pull data callback for a given message type.

        Note that only one callback can be registered for a given message type.

        Args:
            message_type: The message type to register the callback for.
            callback: The function to call when data is received.
            max_concurrency: The maximum number of concurrent requests to allow.
        Raises:
            CommunicationError: If the client is not initialized
            ValueError: If a callback is already registered for the message type.
        """
        await self._ensure_initialized()

        # Register callback
        if message_type not in self._pull_callbacks:
            self._pull_callbacks[message_type] = callback
        else:
            raise ValueError(
                f"Callback already registered for message type {message_type}"
            )

        if max_concurrency is not None:
            self.semaphore = asyncio.Semaphore(value=max_concurrency)
__init__(context, address, bind, socket_ops=None, max_concurrency=None)

Initialize the ZMQ Puller class.

Parameters:

Name Type Description Default
context Context

The ZMQ context.

required
address str

The address to bind or connect to.

required
bind bool

Whether to bind or connect the socket.

required
socket_ops dict

Additional socket options to set.

None
max_concurrency int

The maximum number of concurrent requests to allow.

None
Source code in aiperf/common/comms/zmq/pull_client.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def __init__(
    self,
    context: zmq.asyncio.Context,
    address: str,
    bind: bool,
    socket_ops: dict | None = None,
    max_concurrency: int | None = None,
) -> None:
    """Set up the PULL worker socket.

    Args:
        context (zmq.asyncio.Context): The ZMQ context.
        address (str): The address to bind or connect to.
        bind (bool): Whether to bind or connect the socket.
        socket_ops (dict, optional): Additional socket options to set.
        max_concurrency (int, optional): The maximum number of concurrent requests to allow.
    """
    super().__init__(context, zmq.SocketType.PULL, address, bind, socket_ops)
    # One callback per message type, registered via register_pull_callback().
    self._pull_callbacks: dict[
        MessageTypeT, Callable[[Message], Coroutine[Any, Any, None]]
    ] = {}

    # Concurrency limit: explicit argument wins; otherwise fall back to the
    # AIPERF_WORKER_CONCURRENT_REQUESTS env var (default 500).
    if max_concurrency is None:
        max_concurrency = int(os.getenv("AIPERF_WORKER_CONCURRENT_REQUESTS", 500))
    self.semaphore = asyncio.Semaphore(value=max_concurrency)

register_pull_callback(message_type, callback, max_concurrency=None) async

Register a ZMQ Pull data callback for a given message type.

Note that only one callback can be registered for a given message type.

Parameters:

Name Type Description Default
message_type MessageTypeT

The message type to register the callback for.

required
callback Callable[[Message], Coroutine[Any, Any, None]]

The function to call when data is received.

required
max_concurrency int | None

The maximum number of concurrent requests to allow.

None

Raises: CommunicationError: If the client is not initialized

Source code in aiperf/common/comms/zmq/pull_client.py
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
async def register_pull_callback(
    self,
    message_type: MessageTypeT,
    callback: Callable[[Message], Coroutine[Any, Any, None]],
    max_concurrency: int | None = None,
) -> None:
    """Register a ZMQ Pull data callback for a given message type.

    Note that only one callback can be registered for a given message type.

    Args:
        message_type: The message type to register the callback for.
        callback: The function to call when data is received.
        max_concurrency: The maximum number of concurrent requests to allow.
    Raises:
        CommunicationError: If the client is not initialized
    """
    await self._ensure_initialized()

    # Guard clause: duplicate registrations are a programming error.
    if message_type in self._pull_callbacks:
        raise ValueError(
            f"Callback already registered for message type {message_type}"
        )
    self._pull_callbacks[message_type] = callback

    # Optionally replace the concurrency limiter for this callback's load.
    if max_concurrency is not None:
        self.semaphore = asyncio.Semaphore(value=max_concurrency)

aiperf.common.comms.zmq.push_client

MAX_PUSH_RETRIES = 2 module-attribute

Maximum number of retries for pushing a message.

RETRY_DELAY_INTERVAL_SEC = 0.1 module-attribute

The interval to wait before retrying to push a message.

ZMQPushClient

Bases: BaseZMQClient, AsyncTaskManagerMixin

ZMQ PUSH socket client for sending work to PULL sockets.

The PUSH socket sends messages to PULL sockets in a pipeline pattern, distributing work fairly among available PULL workers.

ASCII Diagram: ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ PUSH │ │ PULL │ │ PULL │ │ (Producer) │ │ (Worker 1) │ │ (Worker 2) │ │ │ └─────────────┘ └─────────────┘ │ Tasks: │ ▲ ▲ │ - Task A │─────────────┘ │ │ - Task B │───────────────────────────────────┘ │ - Task C │─────────────┐ │ - Task D │ ▼ └─────────────┘ ┌─────────────┐ │ PULL │ │ (Worker 3) │ └─────────────┘

Usage Pattern: - Round-robin distribution of work tasks (One-to-Many) - Each message delivered to exactly one worker - Pipeline pattern for distributed processing - Automatic load balancing across available workers

PUSH/PULL is a One-to-Many communication pattern. If you need Many-to-Many, use a ZMQ Proxy as well. See :class:`ZMQPushPullProxy` for more details.

Source code in aiperf/common/comms/zmq/push_client.py
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
@CommunicationClientFactory.register(CommunicationClientType.PUSH)
class ZMQPushClient(BaseZMQClient, AsyncTaskManagerMixin):
    """ZMQ PUSH socket client that sends work to PULL sockets.

    A PUSH socket fair-queues outgoing messages across every connected PULL
    worker (pipeline pattern): each message is delivered to exactly one
    worker, which yields automatic load balancing of work items.

    Usage Pattern:
    - Round-robin distribution of work tasks (One-to-Many)
    - Each message delivered to exactly one worker
    - Pipeline pattern for distributed processing
    - Automatic load balancing across available workers

    PUSH/PULL is a One-to-Many communication pattern. If you need Many-to-Many,
    use a ZMQ Proxy as well. See :class:`ZMQPushPullProxy` for more details.
    """

    def __init__(
        self,
        context: zmq.asyncio.Context,
        address: str,
        bind: bool,
        socket_ops: dict | None = None,
    ) -> None:
        """Initialize the ZMQ Push client class.

        Args:
            context (zmq.asyncio.Context): The ZMQ context.
            address (str): The address to bind or connect to.
            bind (bool): Whether to bind or connect the socket.
            socket_ops (dict, optional): Additional socket options to set.
        """
        super().__init__(context, zmq.SocketType.PUSH, address, bind, socket_ops)

    async def _push_message(
        self,
        message: Message,
        retry_count: int = 0,
        max_retries: int = MAX_PUSH_RETRIES,
    ) -> None:
        """Push a message to the socket, retrying up to ``max_retries`` times.

        Args:
            message: Message to be sent must be a Message object
            retry_count: Current retry count
            max_retries: Maximum number of times to retry pushing the message

        Raises:
            CommunicationError: If the message could not be sent.
        """
        attempt = retry_count
        while True:
            try:
                data_json = message.model_dump_json()
                await self.socket.send_string(data_json)
                self.trace(lambda msg=data_json: f"Pushed json data: {msg}")
                return
            except (asyncio.CancelledError, zmq.ContextTerminated):
                # Shutdown is in progress; drop the message silently.
                return
            except zmq.Again as e:
                # Socket would block; back off briefly and try again.
                if attempt >= max_retries:
                    raise CommunicationError(
                        f"Failed to push data after {attempt} retries: {e}",
                    ) from e
                await asyncio.sleep(RETRY_DELAY_INTERVAL_SEC)
                attempt += 1
            except Exception as e:
                raise CommunicationError(f"Failed to push data: {e}") from e

    async def push(self, message: Message) -> None:
        """Push data to a target. The message will be routed automatically
        based on the message type.

        Args:
            message: Message to be sent must be a Message object
        """
        await self._ensure_initialized()
        await self._push_message(message)

__init__(context, address, bind, socket_ops=None)

Initialize the ZMQ Push client class.

Parameters:

Name Type Description Default
context Context

The ZMQ context.

required
address str

The address to bind or connect to.

required
bind bool

Whether to bind or connect the socket.

required
socket_ops dict

Additional socket options to set.

None
Source code in aiperf/common/comms/zmq/push_client.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def __init__(
    self,
    context: zmq.asyncio.Context,
    address: str,
    bind: bool,
    socket_ops: dict | None = None,
) -> None:
    """
    Initialize the ZMQ Push client class.

    Args:
        context (zmq.asyncio.Context): The ZMQ context.
        address (str): The address to bind or connect to.
        bind (bool): Whether to bind or connect the socket.
        socket_ops (dict, optional): Additional socket options to set.
    """
    # Delegate all socket creation/configuration to the base class,
    # fixing the socket type to PUSH.
    super().__init__(context, zmq.SocketType.PUSH, address, bind, socket_ops)

push(message) async

Push data to a target. The message will be routed automatically based on the message type.

Parameters:

Name Type Description Default
message Message

Message to be sent must be a Message object

required
Source code in aiperf/common/comms/zmq/push_client.py
103
104
105
106
107
108
109
110
111
112
async def push(self, message: Message) -> None:
    """Push data to a target. The message will be routed automatically
    based on the message type.

    Args:
        message: Message to be sent must be a Message object
    """
    # Lazily set up the underlying socket on first use.
    await self._ensure_initialized()

    # _push_message retries transient zmq.Again failures before raising.
    await self._push_message(message)

aiperf.common.comms.zmq.router_reply_client

ZMQRouterReplyClient

Bases: BaseZMQClient, AsyncTaskManagerMixin

ZMQ ROUTER socket client for handling requests from DEALER clients.

The ROUTER socket receives requests from DEALER clients and sends responses back to the originating DEALER client using routing envelopes.

ASCII Diagram: ┌──────────────┐ ┌──────────────┐ │ DEALER │───── Request ─────>│ │ │ (Client) │<──── Response ─────│ │ └──────────────┘ │ │ ┌──────────────┐ │ ROUTER │ │ DEALER │───── Request ─────>│ (Service) │ │ (Client) │<──── Response ─────│ │ └──────────────┘ │ │ ┌──────────────┐ │ │ │ DEALER │───── Request ─────>│ │ │ (Client) │<──── Response ─────│ │ └──────────────┘ └──────────────┘

Usage Pattern: - ROUTER handles requests from multiple DEALER clients - Maintains routing envelopes to send responses back - Many-to-one request handling pattern - Supports concurrent request processing

ROUTER/DEALER is a Many-to-One communication pattern. If you need Many-to-Many, use a ZMQ Proxy as well. See :class:`ZMQDealerRouterProxy` for more details.

Source code in aiperf/common/comms/zmq/router_reply_client.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
@CommunicationClientFactory.register(CommunicationClientType.REPLY)
class ZMQRouterReplyClient(BaseZMQClient, AsyncTaskManagerMixin):
    """
    ZMQ ROUTER socket client for handling requests from DEALER clients.

    The ROUTER socket receives requests from DEALER clients and sends responses
    back to the originating DEALER client using routing envelopes.

    ASCII Diagram:
    ┌──────────────┐                    ┌──────────────┐
    │    DEALER    │───── Request ─────>│              │
    │   (Client)   │<──── Response ─────│              │
    └──────────────┘                    │              │
    ┌──────────────┐                    │    ROUTER    │
    │    DEALER    │───── Request ─────>│  (Service)   │
    │   (Client)   │<──── Response ─────│              │
    └──────────────┘                    │              │
    ┌──────────────┐                    │              │
    │    DEALER    │───── Request ─────>│              │
    │   (Client)   │<──── Response ─────│              │
    └──────────────┘                    └──────────────┘

    Usage Pattern:
    - ROUTER handles requests from multiple DEALER clients
    - Maintains routing envelopes to send responses back
    - Many-to-one request handling pattern
    - Supports concurrent request processing

    ROUTER/DEALER is a Many-to-One communication pattern. If you need Many-to-Many,
    use a ZMQ Proxy as well. see :class:`ZMQDealerRouterProxy` for more details.
    """

    def __init__(
        self,
        context: zmq.asyncio.Context,
        address: str,
        bind: bool,
        socket_ops: dict | None = None,
    ) -> None:
        """
        Initialize the ZMQ Router (Rep) client class.

        Args:
            context (zmq.asyncio.Context): The ZMQ context.
            address (str): The address to bind or connect to.
            bind (bool): Whether to bind or connect the socket.
            socket_ops (dict, optional): Additional socket options to set.
        """
        super().__init__(context, zmq.SocketType.ROUTER, address, bind, socket_ops)

        # Maps message type -> (service_id, handler coroutine). Only one
        # handler may be registered per message type.
        self._request_handlers: dict[
            MessageTypeT,
            tuple[str, Callable[[Message], Coroutine[Any, Any, Message | None]]],
        ] = {}
        # Maps request_id -> future resolved with the handler's response.
        self._response_futures: dict[str, asyncio.Future[Message | None]] = {}

    @on_stop
    async def _on_stop(self) -> None:
        # Cancel any in-flight handler / response-wait tasks on shutdown.
        await self.cancel_all_tasks()

    @on_cleanup
    async def _cleanup(self) -> None:
        self._request_handlers.clear()

    def register_request_handler(
        self,
        service_id: str,
        message_type: MessageTypeT,
        handler: Callable[[Message], Coroutine[Any, Any, Message | None]],
    ) -> None:
        """Register a request handler. Anytime a request is received that matches the
        message type, the handler will be called. The handler should return a response
        message. If the handler returns None, the request will be ignored.

        Note that there is a limit of 1 to 1 mapping between message type and handler.

        Args:
            service_id: The service ID to register the handler for
            message_type: The message type to register the handler for
            handler: The handler to register

        Raises:
            ValueError: If a handler is already registered for the message type.
        """
        if message_type in self._request_handlers:
            raise ValueError(
                f"Handler already registered for message type {message_type}"
            )

        self.debug(
            lambda service_id=service_id,
            type=message_type: f"Registering request handler for {service_id} with message type {type}"
        )
        self._request_handlers[message_type] = (service_id, handler)

    async def _handle_request(self, request_id: str, request: Message) -> None:
        """Handle a request.

        This method will:
        - Parse the request JSON to create a Message object
        - Call the handler for the message type
        - Set the response future
        """
        message_type = request.message_type

        try:
            # Raises KeyError if no handler is registered for this message
            # type; the except below turns that into an ErrorMessage reply.
            _, handler = self._request_handlers[message_type]
            response = await handler(request)

        except Exception as e:
            self.exception(f"Exception calling handler for {message_type}: {e}")
            response = ErrorMessage(
                request_id=request_id,
                error=ErrorDetails.from_exception(e),
            )

        try:
            # Resolving the future wakes up _wait_for_response for this request.
            self._response_futures[request_id].set_result(response)
        except Exception as e:
            self.exception(
                f"Exception setting response future for request {request_id}: {e}"
            )

    async def _wait_for_response(
        self, request_id: str, routing_envelope: tuple[bytes, ...]
    ) -> None:
        """Wait for a response to a request.

        This method will wait for the response future to be set and then send the response
        back to the client.
        """
        try:
            # Wait for the response asynchronously.
            response = await self._response_futures[request_id]

            if response is None:
                # A handler returning None means no reply was produced;
                # report that back to the caller as an error.
                self.warning(
                    lambda req_id=request_id: f"Got None as response for request {req_id}"
                )
                response = ErrorMessage(
                    request_id=request_id,
                    error=ErrorDetails(
                        type="NO_RESPONSE",
                        message="No response was generated for the request.",
                    ),
                )

            # The future has been resolved; drop it from the tracking map.
            self._response_futures.pop(request_id, None)

            # Send the response back to the client.
            await self.socket.send_multipart(
                [*routing_envelope, response.model_dump_json().encode()]
            )
        except Exception as e:
            self.exception(
                f"Exception waiting for response for request {request_id}: {e}"
            )

    @aiperf_task
    async def _rep_router_receiver(self) -> None:
        """Background task for receiving requests and sending responses.

        This method is a coroutine that will run indefinitely until the client is
        shutdown. It will wait for requests from the socket and send responses in
        an asynchronous manner.
        """
        self.debug("Waiting for router reply client to be initialized")
        if not self.is_initialized:
            await self.initialized_event.wait()

        self.debug("Router reply client initialized")

        while not self.stop_event.is_set():
            try:
                # Receive request
                try:
                    data = await self.socket.recv_multipart()
                    self.trace(lambda msg=data: f"Received request: {msg}")

                    # The payload is the last frame; all frames before it form
                    # the ZMQ routing envelope.
                    request = Message.from_json(data[-1])
                    if not request.request_id:
                        self.exception(f"Request ID is missing from request: {data}")
                        continue

                    # Fall back to the request_id when no envelope frames
                    # were received.
                    routing_envelope: tuple[bytes, ...] = (
                        tuple(data[:-1])
                        if len(data) > 1
                        else (request.request_id.encode(),)
                    )
                except zmq.Again:
                    # This means we timed out waiting for a request.
                    # We can continue to the next iteration of the loop.
                    await yield_to_event_loop()
                    continue

                # Create a new response future for this request that will be resolved
                # when the handler returns a response.
                self._response_futures[request.request_id] = asyncio.Future()
                # Handle the request in a new task.
                # Two tasks per request: one runs the handler, one awaits the
                # future and sends the reply back over the socket.
                self.execute_async(self._handle_request(request.request_id, request))
                self.execute_async(
                    self._wait_for_response(request.request_id, routing_envelope)
                )

            except asyncio.CancelledError:
                self.trace(lambda: "Router reply client receiver task cancelled")
                break
            except Exception as e:
                self.exception(f"Exception receiving request: {e}")
                # Sleep for a short time to allow the system to potentially recover
                # if there are temporary issues.
                await asyncio.sleep(0.1)

__init__(context, address, bind, socket_ops=None)

Initialize the ZMQ Router (Rep) client class.

Parameters:

Name Type Description Default
context Context

The ZMQ context.

required
address str

The address to bind or connect to.

required
bind bool

Whether to bind or connect the socket.

required
socket_ops dict

Additional socket options to set.

None
Source code in aiperf/common/comms/zmq/router_reply_client.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def __init__(
    self,
    context: zmq.asyncio.Context,
    address: str,
    bind: bool,
    socket_ops: dict | None = None,
) -> None:
    """
    Initialize the ZMQ Router (Rep) client class.

    Args:
        context (zmq.asyncio.Context): The ZMQ context.
        address (str): The address to bind or connect to.
        bind (bool): Whether to bind or connect the socket.
        socket_ops (dict, optional): Additional socket options to set.
    """
    super().__init__(context, zmq.SocketType.ROUTER, address, bind, socket_ops)

    # Maps message type -> (service_id, handler coroutine). Only one handler
    # may be registered per message type.
    self._request_handlers: dict[
        MessageTypeT,
        tuple[str, Callable[[Message], Coroutine[Any, Any, Message | None]]],
    ] = {}
    # Maps request_id -> future resolved with the handler's response.
    self._response_futures: dict[str, asyncio.Future[Message | None]] = {}

register_request_handler(service_id, message_type, handler)

Register a request handler. Anytime a request is received that matches the message type, the handler will be called. The handler should return a response message. If the handler returns None, the request will be ignored.

Note that there is a limit of 1 to 1 mapping between message type and handler.

Parameters:

Name Type Description Default
service_id str

The service ID to register the handler for

required
message_type MessageTypeT

The message type to register the handler for

required
handler Callable[[Message], Coroutine[Any, Any, Message | None]]

The handler to register

required
Source code in aiperf/common/comms/zmq/router_reply_client.py
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
def register_request_handler(
    self,
    service_id: str,
    message_type: MessageTypeT,
    handler: Callable[[Message], Coroutine[Any, Any, Message | None]],
) -> None:
    """Register a request handler. Anytime a request is received that matches the
    message type, the handler will be called. The handler should return a response
    message. If the handler returns None, the request will be ignored.

    Note that there is a limit of 1 to 1 mapping between message type and handler.

    Args:
        service_id: The service ID to register the handler for
        message_type: The message type to register the handler for
        handler: The handler to register

    Raises:
        ValueError: If a handler is already registered for the message type.
    """
    # Enforce the 1:1 mapping between message type and handler.
    if message_type in self._request_handlers:
        raise ValueError(
            f"Handler already registered for message type {message_type}"
        )

    self.debug(
        lambda service_id=service_id,
        type=message_type: f"Registering request handler for {service_id} with message type {type}"
    )
    self._request_handlers[message_type] = (service_id, handler)

aiperf.common.comms.zmq.sub_client

ZMQSubClient

Bases: BaseZMQClient, AsyncTaskManagerMixin

ZMQ SUB socket client for subscribing to messages from PUB sockets. One-to-Many or Many-to-One communication pattern.

ASCII Diagram: ┌──────────────┐ ┌──────────────┐ │ PUB │───>│ │ │ (Publisher) │ │ │ └──────────────┘ │ SUB │ ┌──────────────┐ │ (Subscriber) │ │ PUB │───>│ │ │ (Publisher) │ │ │ └──────────────┘ └──────────────┘ OR ┌──────────────┐ ┌──────────────┐ │ │───>│ SUB │ │ │ │ (Subscriber) │ │ PUB │ └──────────────┘ │ (Publisher) │ ┌──────────────┐ │ │───>│ SUB │ │ │ │ (Subscriber) │ └──────────────┘ └──────────────┘

Usage Pattern: - Single SUB socket subscribes to multiple PUB publishers (One-to-Many) OR - Multiple SUB sockets subscribe to a single PUB publisher (Many-to-One)

  • Subscribes to specific message topics/types
  • Receives all messages matching subscriptions

SUB/PUB is a One-to-Many communication pattern. If you need Many-to-Many, use a ZMQ Proxy as well. See :class:`ZMQXPubXSubProxy` for more details.

Source code in aiperf/common/comms/zmq/sub_client.py
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
@CommunicationClientFactory.register(CommunicationClientType.SUB)
class ZMQSubClient(BaseZMQClient, AsyncTaskManagerMixin):
    """
    ZMQ SUB socket client for subscribing to messages from PUB sockets.
    One-to-Many or Many-to-One communication pattern.

    ASCII Diagram:
    ┌──────────────┐    ┌──────────────┐
    │     PUB      │───>│              │
    │ (Publisher)  │    │              │
    └──────────────┘    │     SUB      │
    ┌──────────────┐    │ (Subscriber) │
    │     PUB      │───>│              │
    │ (Publisher)  │    │              │
    └──────────────┘    └──────────────┘
    OR
    ┌──────────────┐    ┌──────────────┐
    │              │───>│     SUB      │
    │              │    │ (Subscriber) │
    │     PUB      │    └──────────────┘
    │ (Publisher)  │    ┌──────────────┐
    │              │───>│     SUB      │
    │              │    │ (Subscriber) │
    └──────────────┘    └──────────────┘


    Usage Pattern:
    - Single SUB socket subscribes to multiple PUB publishers (One-to-Many)
    OR
    - Multiple SUB sockets subscribe to a single PUB publisher (Many-to-One)

    - Subscribes to specific message topics/types
    - Receives all messages matching subscriptions

    SUB/PUB is a One-to-Many communication pattern. If you need Many-to-Many,
    use a ZMQ Proxy as well. see :class:`ZMQXPubXSubProxy` for more details.
    """

    def __init__(
        self,
        context: zmq.asyncio.Context,
        address: str,
        bind: bool,
        socket_ops: dict | None = None,
    ) -> None:
        """
        Initialize the ZMQ Subscriber class.

        Args:
            context (zmq.asyncio.Context): The ZMQ context.
            address (str): The address to bind or connect to.
            bind (bool): Whether to bind or connect the socket.
            socket_ops (dict, optional): Additional socket options to set.
        """
        super().__init__(context, zmq.SocketType.SUB, address, bind, socket_ops)

        # Maps message type -> list of callbacks invoked for each received message.
        self._subscribers: dict[MessageTypeT, list[Callable[[Message], Any]]] = {}

    @on_stop
    async def _on_stop(self) -> None:
        # Cancel the receiver and any in-flight message handler tasks.
        await self.cancel_all_tasks()

    async def subscribe_all(
        self, message_callback_map: dict[MessageTypeT, Callable[[Message], Any]]
    ) -> None:
        """Subscribe to all message_types in the map."""
        await self._ensure_initialized()
        for message_type, callback in message_callback_map.items():
            await self._subscribe_internal(message_type, callback)
        # TODO: HACK: This is a hack to ensure that the subscriptions are registered
        # since we do not have any confirmation from the server that the subscriptions
        # are registered, yet.
        await asyncio.sleep(0.1)

    async def subscribe(
        self, message_type: MessageTypeT, callback: Callable[[Message], Any]
    ) -> None:
        """Subscribe to a message_type.

        Args:
            message_type: MessageTypeT to subscribe to
            callback: Function to call when a message is received (receives Message object)

        Raises:
            CommunicationError: If the subscription was not successful.
        """
        await self._ensure_initialized()
        await self._subscribe_internal(message_type, callback)
        # TODO: HACK: This is a hack to ensure that the subscriptions are registered
        # since we do not have any confirmation from the server that the subscriptions
        # are registered, yet.
        await asyncio.sleep(0.1)

    async def _subscribe_internal(
        self, message_type: MessageTypeT, callback: Callable[[Message], Any]
    ) -> None:
        """Subscribe to a message_type.

        Args:
            message_type: MessageTypeT to subscribe to
            callback: Function to call when a message is received (receives Message object)

        Raises:
            CommunicationError: If the socket subscription fails.
        """
        try:
            # Only subscribe to message_type if this is the first callback for this type
            if message_type not in self._subscribers:
                self.socket.subscribe(message_type.encode())
                self._subscribers[message_type] = []

            # Register callback
            self._subscribers[message_type].append(callback)

            self.trace(
                lambda: f"Subscribed to message_type: {message_type}, {self._subscribers[message_type]}"
            )

        except Exception as e:
            self.exception(f"Exception subscribing to message_type {message_type}: {e}")
            raise CommunicationError(
                f"Failed to subscribe to message_type {message_type}: {e}",
            ) from e

    async def _handle_message(self, topic_bytes: bytes, message_bytes: bytes) -> None:
        """Handle a message from a subscribed message_type.

        The topic frame carries the message type; the second frame carries
        the JSON-encoded message body.
        """
        message_type = topic_bytes.decode()
        message_json = message_bytes.decode()
        self.trace(
            lambda: f"Received message from message_type: '{message_type}', message: {message_json}"
        )

        message = Message.from_json(message_json)

        # Call callbacks with the parsed message object
        if message_type in self._subscribers:
            with contextlib.suppress(Exception):  # Ignore errors, they will get logged
                await call_all_functions(self._subscribers[message_type], message)

    @aiperf_task
    async def _sub_receiver(self) -> None:
        """Background task for receiving messages from subscribed topics.

        This method is a coroutine that will run indefinitely until the client is
        shutdown. It will wait for messages from the socket and handle them.
        """
        if not self.is_initialized:
            # NOTE(review): unlike the other trace calls in this class, this one
            # passes %-style lazy args instead of a lambda — confirm the logger
            # supports both forms.
            self.trace("Sub client %s waiting for initialization", self.client_id)
            await self.initialized_event.wait()
            self.trace(lambda: f"Sub client {self.client_id} initialized")

        while not self.stop_event.is_set():
            try:
                # Expect exactly two frames: topic and JSON payload.
                (
                    topic_bytes,
                    message_bytes,
                ) = await self.socket.recv_multipart()

                # Dispatch handling to a task so receiving is not blocked.
                self.execute_async(self._handle_message(topic_bytes, message_bytes))

            except (asyncio.CancelledError, zmq.ContextTerminated):
                self.trace(
                    lambda: f"Sub client {self.client_id} receiver task cancelled"
                )
                break

            except zmq.Again:
                await yield_to_event_loop()
                continue

            except Exception as e:
                self.exception(
                    f"Exception receiving message from subscription: {e}, {type(e)}"
                )
                # Sleep for a short time to allow the system to potentially recover
                # if there are temporary issues.
                await asyncio.sleep(0.1)

__init__(context, address, bind, socket_ops=None)

Initialize the ZMQ Subscriber class.

Parameters:

Name Type Description Default
context Context

The ZMQ context.

required
address str

The address to bind or connect to.

required
bind bool

Whether to bind or connect the socket.

required
socket_ops dict

Additional socket options to set.

None
Source code in aiperf/common/comms/zmq/sub_client.py
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def __init__(
    self,
    context: zmq.asyncio.Context,
    address: str,
    bind: bool,
    socket_ops: dict | None = None,
) -> None:
    """
    Initialize the ZMQ Subscriber class.

    Args:
        context (zmq.asyncio.Context): The ZMQ context.
        address (str): The address to bind or connect to.
        bind (bool): Whether to bind or connect the socket.
        socket_ops (dict, optional): Additional socket options to set.
    """
    super().__init__(context, zmq.SocketType.SUB, address, bind, socket_ops)

    # Maps message type -> list of callbacks invoked for each received message.
    self._subscribers: dict[MessageTypeT, list[Callable[[Message], Any]]] = {}

subscribe(message_type, callback) async

Subscribe to a message_type.

Parameters:

Name Type Description Default
message_type MessageTypeT

MessageTypeT to subscribe to

required
callback Callable[[Message], Any]

Function to call when a message is received (receives Message object)

required
Source code in aiperf/common/comms/zmq/sub_client.py
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
async def subscribe(
    self, message_type: MessageTypeT, callback: Callable[[Message], Any]
) -> None:
    """Subscribe to a message_type.

    Args:
        message_type: MessageTypeT to subscribe to
        callback: Function to call when a message is received (receives Message object)

    Raises:
        CommunicationError: If the subscription was not successful.
    """
    await self._ensure_initialized()
    await self._subscribe_internal(message_type, callback)
    # TODO: HACK: This is a hack to ensure that the subscriptions are registered
    # since we do not have any confirmation from the server that the subscriptions
    # are registered, yet.
    await asyncio.sleep(0.1)

subscribe_all(message_callback_map) async

Subscribe to all message_types in the map.

Source code in aiperf/common/comms/zmq/sub_client.py
83
84
85
86
87
88
89
90
91
92
93
async def subscribe_all(
    self, message_callback_map: dict[MessageTypeT, Callable[[Message], Any]]
) -> None:
    """Register every (message_type, callback) pair in the given map."""
    await self._ensure_initialized()
    for msg_type, cb in message_callback_map.items():
        await self._subscribe_internal(msg_type, cb)
    # TODO: HACK: Give the broker a moment to process the subscription
    # requests, since there is currently no server-side confirmation that
    # they have been registered.
    await asyncio.sleep(0.1)

aiperf.common.comms.zmq.zmq_base_client

BaseZMQClient

Bases: AIPerfTaskMixin, AIPerfLoggerMixin

Base class for all ZMQ clients. It can be used as-is to create a new ZMQ client, or it can be subclassed to create specific ZMQ client functionality.

It inherits from the :class:`AIPerfTaskMixin`, allowing derived classes to implement specific hooks.

Source code in aiperf/common/comms/zmq/zmq_base_client.py
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
@supports_hooks(
    AIPerfHook.ON_INIT,
    AIPerfHook.ON_STOP,
    AIPerfHook.ON_CLEANUP,
    AIPerfTaskHook.AIPERF_TASK,
)
class BaseZMQClient(AIPerfTaskMixin, AIPerfLoggerMixin):
    """Base class for all ZMQ clients. It can be used as-is to create a new ZMQ client,
    or it can be subclassed to create specific ZMQ client functionality.

    It inherits from the :class:`AIPerfTaskMixin`, allowing derived
    classes to implement specific hooks.
    """

    def __init__(
        self,
        context: zmq.asyncio.Context,
        socket_type: zmq.SocketType,
        address: str,
        bind: bool,
        socket_ops: dict | None = None,
        client_id: str | None = None,
    ) -> None:
        """
        Initialize the ZMQ Base class.

        Args:
            context (zmq.asyncio.Context): The ZMQ context.
            address (str): The address to bind or connect to.
            bind (bool): Whether to BIND or CONNECT the socket.
            socket_type (SocketType): The type of ZMQ socket (eg. PUB, SUB, ROUTER, DEALER, etc.).
            socket_ops (dict, optional): Additional socket options to set.
            client_id (str, optional): Identifier used in logs; auto-generated from
                the socket type name and a short UUID when not provided.
        """
        self.stop_event: asyncio.Event = asyncio.Event()
        self.initialized_event: asyncio.Event = asyncio.Event()
        self.context: zmq.asyncio.Context = context
        self.address: str = address
        self.bind: bool = bind
        self.socket_type: zmq.SocketType = socket_type
        # The socket is created lazily in initialize(); None marks "not initialized".
        self._socket: zmq.asyncio.Socket | None = None
        self.socket_ops: dict = socket_ops or {}
        self.client_id: str = (
            client_id
            or f"{self.socket_type.name.lower()}_client_{uuid.uuid4().hex[:8]}"
        )
        # Logger name is the client_id, so every log line identifies this client.
        super().__init__(logger_name=self.client_id)
        self.trace(lambda: f"ZMQ client __init__: {self.client_id}")

    @property
    def is_initialized(self) -> bool:
        """Check if the client is initialized."""
        return self.initialized_event.is_set()

    @property
    def stop_requested(self) -> bool:
        """Check if the client has been requested to stop."""
        return self.stop_event.is_set()

    @property
    def socket_type_name(self) -> str:
        """Get the name of the socket type."""
        return self.socket_type.name

    @property
    def socket(self) -> zmq.asyncio.Socket:
        """Get the zmq socket for the client.

        Raises:
            CommunicationError: If the client is not initialized
        """
        if not self._socket:
            raise CommunicationError(
                "Communication channels are not initialized",
            )
        return self._socket

    async def _ensure_initialized(self) -> None:
        """Ensure the communication channels are initialized and not shutdown.

        If not initialized, it will automatically initialize.

        Raises:
            CommunicationError: If the communication channels are shutdown
        """
        if not self.is_initialized:
            await self.initialize()
        if self.stop_requested:
            raise asyncio.CancelledError()

    async def initialize(self) -> None:
        """Initialize the communication.

        This method will:
        - Create the zmq socket
        - Bind or connect the socket to the address
        - Set the socket options
        - Run the AIPerfHook.ON_INIT hooks

        Raises:
            InitializationError: If the socket could not be created or configured.
        """
        try:
            self._socket = self.context.socket(self.socket_type)
            if self.bind:
                self.debug(
                    lambda: f"ZMQ {self.socket_type_name} socket initialized, try BIND to {self.address} ({self.client_id})"
                )
                self._socket.bind(self.address)
            else:
                self.debug(
                    lambda: f"ZMQ {self.socket_type_name} socket initialized, try CONNECT to {self.address} ({self.client_id})"
                )
                self._socket.connect(self.address)

            # Set default timeouts
            self._socket.setsockopt(zmq.RCVTIMEO, ZMQSocketDefaults.RCVTIMEO)
            self._socket.setsockopt(zmq.SNDTIMEO, ZMQSocketDefaults.SNDTIMEO)

            # Set performance-oriented socket options
            self._socket.setsockopt(zmq.TCP_KEEPALIVE, ZMQSocketDefaults.TCP_KEEPALIVE)
            self._socket.setsockopt(
                zmq.TCP_KEEPALIVE_IDLE, ZMQSocketDefaults.TCP_KEEPALIVE_IDLE
            )
            self._socket.setsockopt(
                zmq.TCP_KEEPALIVE_INTVL, ZMQSocketDefaults.TCP_KEEPALIVE_INTVL
            )
            self._socket.setsockopt(
                zmq.TCP_KEEPALIVE_CNT, ZMQSocketDefaults.TCP_KEEPALIVE_CNT
            )
            self._socket.setsockopt(zmq.IMMEDIATE, ZMQSocketDefaults.IMMEDIATE)
            self._socket.setsockopt(zmq.LINGER, ZMQSocketDefaults.LINGER)

            # Set additional socket options requested by the caller
            # (applied last, so the caller can override the defaults above)
            for key, val in self.socket_ops.items():
                self._socket.setsockopt(key, val)

            # ON_INIT hooks run before the initialized flag is set, so hook code
            # observes a fully configured but not-yet-"ready" client.
            await self.run_hooks(AIPerfHook.ON_INIT)

            self.initialized_event.set()
            self.debug(
                lambda: f"ZMQ {self.socket_type_name} socket {'BOUND' if self.bind else 'CONNECTED'} to {self.address} ({self.client_id})"
            )

        except AIPerfError:
            raise  # re-raise it up the stack
        except Exception as e:
            raise InitializationError(f"Failed to initialize ZMQ socket: {e}") from e

    async def shutdown(self) -> None:
        """Shutdown the communication.

        This method will:
        - Close the zmq socket
        - Run the AIPerfHook.ON_CLEANUP hooks

        Idempotent: returns immediately if a stop was already requested.
        """
        if self.stop_requested:
            return

        self.stop_event.set()

        # ON_STOP hooks run first; non-AIPerf hook failures are logged, not raised.
        try:
            await self.run_hooks(AIPerfHook.ON_STOP)
        except AIPerfError:
            raise  # re-raise it up the stack
        except Exception as e:
            self.exception(
                f"Uncaught exception running ON_STOP hooks: {e} ({self.client_id})"
            )

        try:
            await self.run_hooks(AIPerfHook.ON_CLEANUP)
        except AIPerfError:
            raise  # re-raise it up the stack
        except Exception as e:
            self.exception(
                f"Uncaught exception cleaning up ZMQ socket: {e} ({self.client_id})"
            )

        finally:
            # Always close the socket, even if the cleanup hooks failed;
            # the inner finally guarantees _socket is reset to None.
            try:
                if self._socket:
                    self._socket.close()
            except zmq.ContextTerminated:
                self.debug(
                    lambda: f"ZMQ context already terminated, skipping socket close ({self.client_id})"
                )
                return
            except AIPerfError:
                raise  # re-raise it up the stack
            except Exception as e:
                self.exception(
                    f"Uncaught exception shutting down ZMQ socket: {e} ({self.client_id})"
                )
            finally:
                self._socket = None

is_initialized property

Check if the client is initialized.

socket property

Get the zmq socket for the client.

Raises:

Type Description
CommunicationError

If the client is not initialized

socket_type_name property

Get the name of the socket type.

stop_requested property

Check if the client has been requested to stop.

__init__(context, socket_type, address, bind, socket_ops=None, client_id=None)

Initialize the ZMQ Base class.

Parameters:

Name Type Description Default
context Context

The ZMQ context.

required
address str

The address to bind or connect to.

required
bind bool

Whether to BIND or CONNECT the socket.

required
socket_type SocketType

The type of ZMQ socket (eg. PUB, SUB, ROUTER, DEALER, etc.).

required
socket_ops dict

Additional socket options to set.

None
Source code in aiperf/common/comms/zmq/zmq_base_client.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def __init__(
    self,
    context: zmq.asyncio.Context,
    socket_type: zmq.SocketType,
    address: str,
    bind: bool,
    socket_ops: dict | None = None,
    client_id: str | None = None,
) -> None:
    """
    Initialize the ZMQ Base class.

    Args:
        context (zmq.asyncio.Context): The ZMQ context.
        address (str): The address to bind or connect to.
        bind (bool): Whether to BIND or CONNECT the socket.
        socket_type (SocketType): The type of ZMQ socket (eg. PUB, SUB, ROUTER, DEALER, etc.).
        socket_ops (dict, optional): Additional socket options to set.
        client_id (str, optional): Identifier used in logs; auto-generated from
            the socket type name and a short UUID when not provided.
    """
    self.stop_event: asyncio.Event = asyncio.Event()
    self.initialized_event: asyncio.Event = asyncio.Event()
    self.context: zmq.asyncio.Context = context
    self.address: str = address
    self.bind: bool = bind
    self.socket_type: zmq.SocketType = socket_type
    # The socket is created lazily in initialize(); None marks "not initialized".
    self._socket: zmq.asyncio.Socket | None = None
    self.socket_ops: dict = socket_ops or {}
    self.client_id: str = (
        client_id
        or f"{self.socket_type.name.lower()}_client_{uuid.uuid4().hex[:8]}"
    )
    # Logger name is the client_id, so every log line identifies this client.
    super().__init__(logger_name=self.client_id)
    self.trace(lambda: f"ZMQ client __init__: {self.client_id}")

initialize() async

Initialize the communication.

This method will: - Create the zmq socket - Bind or connect the socket to the address - Set the socket options - Run the AIPerfHook.ON_INIT hooks

Source code in aiperf/common/comms/zmq/zmq_base_client.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
async def initialize(self) -> None:
    """Initialize the communication.

    This method will:
    - Create the zmq socket
    - Bind or connect the socket to the address
    - Set the socket options
    - Run the AIPerfHook.ON_INIT hooks

    Raises:
        InitializationError: If the socket could not be created or configured.
    """
    try:
        self._socket = self.context.socket(self.socket_type)
        if self.bind:
            self.debug(
                lambda: f"ZMQ {self.socket_type_name} socket initialized, try BIND to {self.address} ({self.client_id})"
            )
            self._socket.bind(self.address)
        else:
            self.debug(
                lambda: f"ZMQ {self.socket_type_name} socket initialized, try CONNECT to {self.address} ({self.client_id})"
            )
            self._socket.connect(self.address)

        # Set default timeouts
        self._socket.setsockopt(zmq.RCVTIMEO, ZMQSocketDefaults.RCVTIMEO)
        self._socket.setsockopt(zmq.SNDTIMEO, ZMQSocketDefaults.SNDTIMEO)

        # Set performance-oriented socket options
        self._socket.setsockopt(zmq.TCP_KEEPALIVE, ZMQSocketDefaults.TCP_KEEPALIVE)
        self._socket.setsockopt(
            zmq.TCP_KEEPALIVE_IDLE, ZMQSocketDefaults.TCP_KEEPALIVE_IDLE
        )
        self._socket.setsockopt(
            zmq.TCP_KEEPALIVE_INTVL, ZMQSocketDefaults.TCP_KEEPALIVE_INTVL
        )
        self._socket.setsockopt(
            zmq.TCP_KEEPALIVE_CNT, ZMQSocketDefaults.TCP_KEEPALIVE_CNT
        )
        self._socket.setsockopt(zmq.IMMEDIATE, ZMQSocketDefaults.IMMEDIATE)
        self._socket.setsockopt(zmq.LINGER, ZMQSocketDefaults.LINGER)

        # Set additional socket options requested by the caller
        # (applied last, so the caller can override the defaults above)
        for key, val in self.socket_ops.items():
            self._socket.setsockopt(key, val)

        # ON_INIT hooks run before the initialized flag is set, so hook code
        # observes a fully configured but not-yet-"ready" client.
        await self.run_hooks(AIPerfHook.ON_INIT)

        self.initialized_event.set()
        self.debug(
            lambda: f"ZMQ {self.socket_type_name} socket {'BOUND' if self.bind else 'CONNECTED'} to {self.address} ({self.client_id})"
        )

    except AIPerfError:
        raise  # re-raise it up the stack
    except Exception as e:
        raise InitializationError(f"Failed to initialize ZMQ socket: {e}") from e

shutdown() async

Shutdown the communication.

This method will: - Close the zmq socket - Run the AIPerfHook.ON_CLEANUP hooks

Source code in aiperf/common/comms/zmq/zmq_base_client.py
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
async def shutdown(self) -> None:
    """Shutdown the communication.

    This method will:
    - Close the zmq socket
    - Run the AIPerfHook.ON_CLEANUP hooks

    Idempotent: returns immediately if a stop was already requested.
    """
    if self.stop_requested:
        return

    self.stop_event.set()

    # ON_STOP hooks run first; non-AIPerf hook failures are logged, not raised.
    try:
        await self.run_hooks(AIPerfHook.ON_STOP)
    except AIPerfError:
        raise  # re-raise it up the stack
    except Exception as e:
        self.exception(
            f"Uncaught exception running ON_STOP hooks: {e} ({self.client_id})"
        )

    try:
        await self.run_hooks(AIPerfHook.ON_CLEANUP)
    except AIPerfError:
        raise  # re-raise it up the stack
    except Exception as e:
        self.exception(
            f"Uncaught exception cleaning up ZMQ socket: {e} ({self.client_id})"
        )

    finally:
        # Always close the socket, even if the cleanup hooks failed;
        # the inner finally guarantees _socket is reset to None.
        try:
            if self._socket:
                self._socket.close()
        except zmq.ContextTerminated:
            self.debug(
                lambda: f"ZMQ context already terminated, skipping socket close ({self.client_id})"
            )
            return
        except AIPerfError:
            raise  # re-raise it up the stack
        except Exception as e:
            self.exception(
                f"Uncaught exception shutting down ZMQ socket: {e} ({self.client_id})"
            )
        finally:
            self._socket = None

aiperf.common.comms.zmq.zmq_comms

BaseZMQCommunication

Bases: BaseCommunication, AIPerfLoggerMixin, ABC

ZeroMQ-based implementation of the Communication interface.

Uses ZeroMQ for publish/subscribe and request/reply patterns to facilitate communication between AIPerf components.

Source code in aiperf/common/comms/zmq/zmq_comms.py
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
class BaseZMQCommunication(BaseCommunication, AIPerfLoggerMixin, ABC):
    """ZeroMQ-based implementation of the Communication interface.

    Uses ZeroMQ for publish/subscribe and request/reply patterns to
    facilitate communication between AIPerf components.
    """

    def __init__(
        self,
        config: BaseZMQCommunicationConfig,
    ) -> None:
        """Initialize the communication layer.

        Args:
            config: The ZMQ communication configuration (addresses/transport).
        """
        super().__init__()
        self.stop_event: asyncio.Event = asyncio.Event()
        self.initialized_event: asyncio.Event = asyncio.Event()
        self.config = config

        # Use the process-wide shared ZMQ context so all clients share IO threads.
        self.context = zmq.asyncio.Context.instance()
        self.clients: list[BaseZMQClient] = []

        self.debug(f"ZMQ communication using protocol: {type(self.config).__name__}")

    @property
    def is_initialized(self) -> bool:
        """Check if communication channels are initialized."""
        return self.initialized_event.is_set()

    @property
    def stop_requested(self) -> bool:
        """Check if the communication channels are being shutdown."""
        return self.stop_event.is_set()

    def get_address(self, address_type: CommunicationClientAddressType | str) -> str:
        """Get the actual address based on the address type from the config.

        Plain strings are treated as literal addresses and returned unchanged.
        """
        if isinstance(address_type, CommunicationClientAddressType):
            return self.config.get_address(address_type)
        return address_type

    async def initialize(self) -> None:
        """Initialize communication channels.

        Initializes all not-yet-initialized clients concurrently; a no-op if
        the channels are already initialized.
        """
        if self.is_initialized:
            return

        tasks = [
            asyncio.create_task(client.initialize())
            for client in self.clients
            if not client.is_initialized
        ]
        await asyncio.gather(*tasks)
        self.initialized_event.set()

    async def shutdown(self) -> None:
        """Gracefully shutdown communication channels.

        This method will wait for all clients to shutdown before shutting down
        the context. It is a no-op if shutdown was already requested.

        Raises:
            ShutdownError: If an unexpected error occurs during shutdown.
        """
        # Idempotence guard: a second call returns immediately.
        if self.stop_event.is_set():
            return

        self.stop_event.set()

        try:
            await asyncio.gather(
                *(
                    client.shutdown()
                    for client in self.clients
                    if client.is_initialized
                ),
            )

            self.context.term()

        # Cancellation, timeouts, and an already-terminated context are
        # expected during teardown and are only logged.
        except asyncio.CancelledError:
            self.debug("ZMQ communication shutdown cancelled")

        except asyncio.TimeoutError:
            self.debug("ZMQ communication shutdown timed out")

        except zmq.ContextTerminated:
            self.debug("ZMQ communication context already terminated")

        except Exception as e:
            raise ShutdownError(
                "Failed to shutdown ZMQ communication",
            ) from e

        finally:
            self.clients.clear()

    def create_client(
        self,
        client_type: CommunicationClientType,
        address: CommunicationClientAddressType | str,
        bind: bool = False,
        socket_ops: dict | None = None,
    ) -> CommunicationClientProtocol:
        """Create a communication client for a given client type and address.

        Args:
            client_type: The type of client to create.
            address: The type of address to use when looking up in the communication config, or the address itself.
            bind: Whether to bind or connect the socket.
            socket_ops: Additional socket options to set.
        """
        client = CommunicationClientFactory.create_instance(
            client_type,
            context=self.context,
            address=self.get_address(address),
            bind=bind,
            socket_ops=socket_ops,
        )

        # Track the client so initialize()/shutdown() manage its lifecycle.
        self.clients.append(client)
        return client

is_initialized property

Check if communication channels are initialized.

stop_requested property

Check if the communication channels are being shutdown.

create_client(client_type, address, bind=False, socket_ops=None)

Create a communication client for a given client type and address.

Parameters:

Name Type Description Default
client_type CommunicationClientType

The type of client to create.

required
address CommunicationClientAddressType | str

The type of address to use when looking up in the communication config, or the address itself.

required
bind bool

Whether to bind or connect the socket.

False
socket_ops dict | None

Additional socket options to set.

None
Source code in aiperf/common/comms/zmq/zmq_comms.py
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
def create_client(
    self,
    client_type: CommunicationClientType,
    address: CommunicationClientAddressType | str,
    bind: bool = False,
    socket_ops: dict | None = None,
) -> CommunicationClientProtocol:
    """Create, register, and return a communication client.

    Args:
        client_type: The type of client to create.
        address: The type of address to use when looking up in the communication config, or the address itself.
        bind: Whether to bind or connect the socket.
        socket_ops: Additional socket options to set.
    """
    resolved_address = self.get_address(address)
    new_client = CommunicationClientFactory.create_instance(
        client_type,
        context=self.context,
        address=resolved_address,
        bind=bind,
        socket_ops=socket_ops,
    )
    # Track the client so the communication layer manages its lifecycle.
    self.clients.append(new_client)
    return new_client

get_address(address_type)

Get the actual address based on the address type from the config.

Source code in aiperf/common/comms/zmq/zmq_comms.py
61
62
63
64
65
def get_address(self, address_type: CommunicationClientAddressType | str) -> str:
    """Resolve an address-type enum to its configured address.

    Plain strings are treated as literal addresses and returned unchanged.
    """
    if not isinstance(address_type, CommunicationClientAddressType):
        return address_type
    return self.config.get_address(address_type)

initialize() async

Initialize communication channels.

Source code in aiperf/common/comms/zmq/zmq_comms.py
67
68
69
70
71
72
73
74
75
76
77
78
async def initialize(self) -> None:
    """Initialize communication channels."""
    if self.is_initialized:
        return

    tasks = []
    for client in self.clients:
        if not client.is_initialized:
            tasks.append(asyncio.create_task(client.initialize()))

    await asyncio.gather(*tasks)
    self.initialized_event.set()

shutdown() async

Gracefully shutdown communication channels.

This method will wait for all clients to shutdown before shutting down the context.

Returns:

Type Description
None

Completes when all clients have shut down; unexpected failures raise ShutdownError (the method itself returns None).

Source code in aiperf/common/comms/zmq/zmq_comms.py
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
async def shutdown(self) -> None:
    """Gracefully shutdown communication channels.

    This method will wait for all clients to shutdown before shutting down
    the context. It is a no-op if shutdown was already requested.

    Raises:
        ShutdownError: If an unexpected error occurs during shutdown.
    """
    # Idempotence guard: a second call returns immediately.
    if self.stop_event.is_set():
        return

    # Mark shutdown as requested before tearing anything down.
    self.stop_event.set()

    try:
        await asyncio.gather(
            *(
                client.shutdown()
                for client in self.clients
                if client.is_initialized
            ),
        )

        self.context.term()

    # Cancellation, timeouts, and an already-terminated context are
    # expected during teardown and are only logged.
    except asyncio.CancelledError:
        self.debug("ZMQ communication shutdown cancelled")

    except asyncio.TimeoutError:
        self.debug("ZMQ communication shutdown timed out")

    except zmq.ContextTerminated:
        self.debug("ZMQ communication context already terminated")

    except Exception as e:
        raise ShutdownError(
            "Failed to shutdown ZMQ communication",
        ) from e

    finally:
        self.clients.clear()

ZMQIPCCommunication

Bases: BaseZMQCommunication

ZeroMQ-based implementation of the Communication interface using IPC transport.

Source code in aiperf/common/comms/zmq/zmq_comms.py
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
@CommunicationFactory.register(CommunicationBackend.ZMQ_IPC)
class ZMQIPCCommunication(BaseZMQCommunication):
    """ZeroMQ-based implementation of the Communication interface using IPC transport."""

    def __init__(self, config: ZMQIPCConfig | None = None) -> None:
        """Initialize ZMQ IPC communication.

        Args:
            config: ZMQIPCConfig object with configuration parameters.
                A default ZMQIPCConfig is used when None.
        """
        super().__init__(config or ZMQIPCConfig())
        # call after super init so that way self.config is set
        self._setup_ipc_directory()

    async def initialize(self) -> None:
        """Initialize communication channels.

        Note: the IPC socket directory is created in ``__init__`` (via
        ``_setup_ipc_directory``); this override simply delegates to the
        base class.
        """
        await super().initialize()

    async def shutdown(self) -> None:
        """Gracefully shutdown communication channels.

        This method will wait for all clients to shutdown before shutting down
        the context, then remove any leftover IPC socket files from disk.
        """
        await super().shutdown()
        self._cleanup_ipc_sockets()

    def _setup_ipc_directory(self) -> None:
        """Create IPC socket directory if using IPC transport."""
        self._ipc_socket_dir = Path(self.config.path)
        if self._ipc_socket_dir.exists():
            self.debug(f"IPC socket directory already exists: {self._ipc_socket_dir}")
            return
        self.debug(
            f"IPC socket directory does not exist, creating: {self._ipc_socket_dir}"
        )
        self._ipc_socket_dir.mkdir(parents=True, exist_ok=True)
        self.debug(f"Created IPC socket directory: {self._ipc_socket_dir}")

    def _cleanup_ipc_sockets(self) -> None:
        """Clean up IPC socket files."""
        if self._ipc_socket_dir and self._ipc_socket_dir.exists():
            # Remove all .ipc files in the directory; missing_ok tolerates
            # files that disappear between the glob and the unlink.
            for ipc_file in self._ipc_socket_dir.glob("*.ipc"):
                try:
                    ipc_file.unlink(missing_ok=True)
                    self.debug(f"Removed IPC socket file: {ipc_file}")
                except OSError as e:
                    # Bind loop variables as lambda defaults so the lazy log
                    # message captures the values from this iteration.
                    self.warning(
                        lambda ipc_file=ipc_file,
                        e=e: f"Failed to remove IPC socket file {ipc_file}: {e}"
                    )

__init__(config=None)

Initialize ZMQ IPC communication.

Parameters:

Name Type Description Default
config ZMQIPCConfig | None

ZMQIPCConfig object with configuration parameters

None
Source code in aiperf/common/comms/zmq/zmq_comms.py
170
171
172
173
174
175
176
177
178
def __init__(self, config: ZMQIPCConfig | None = None) -> None:
    """Initialize ZMQ IPC communication.

    Args:
        config: ZMQIPCConfig object with configuration parameters.
            A default ZMQIPCConfig is used when None.
    """
    super().__init__(config or ZMQIPCConfig())
    # call after super init so that way self.config is set
    self._setup_ipc_directory()

initialize() async

Initialize communication channels.

This method will create the IPC socket directory if needed.

Source code in aiperf/common/comms/zmq/zmq_comms.py
180
181
182
183
184
185
async def initialize(self) -> None:
    """Initialize communication channels.

    Note: the IPC socket directory is created in ``__init__`` via
    ``_setup_ipc_directory``, not here; this override delegates to the
    base class unchanged.
    """
    await super().initialize()

shutdown() async

Gracefully shutdown communication channels.

This method will wait for all clients to shutdown before shutting down the context.

Source code in aiperf/common/comms/zmq/zmq_comms.py
187
188
189
190
191
192
193
194
async def shutdown(self) -> None:
    """Gracefully shutdown communication channels.

    This method will wait for all clients to shutdown before shutting down
    the context, then remove any leftover IPC socket files from disk.
    """
    await super().shutdown()
    self._cleanup_ipc_sockets()

ZMQTCPCommunication

Bases: BaseZMQCommunication

ZeroMQ-based implementation of the Communication interface using TCP transport.

Source code in aiperf/common/comms/zmq/zmq_comms.py
153
154
155
156
157
158
159
160
161
162
163
@CommunicationFactory.register(CommunicationBackend.ZMQ_TCP)
class ZMQTCPCommunication(BaseZMQCommunication):
    """ZeroMQ-based implementation of the Communication interface using TCP transport."""

    def __init__(self, config: ZMQTCPConfig | None = None) -> None:
        """Initialize ZMQ TCP communication.

        Args:
            config: ZMQTCPConfig object with configuration parameters.
                A default ZMQTCPConfig is used when None.
        """
        super().__init__(config or ZMQTCPConfig())

__init__(config=None)

Initialize ZMQ TCP communication.

Parameters:

Name Type Description Default
config ZMQTCPConfig | None

ZMQTCPTransportConfig object with configuration parameters

None
Source code in aiperf/common/comms/zmq/zmq_comms.py
157
158
159
160
161
162
163
def __init__(self, config: ZMQTCPConfig | None = None) -> None:
    """Initialize ZMQ TCP communication.

    Args:
        config: ZMQTCPConfig object with configuration parameters.
            A default ZMQTCPConfig is used when None.
    """
    super().__init__(config or ZMQTCPConfig())

aiperf.common.comms.zmq.zmq_defaults

ZMQSocketDefaults

Default values for ZMQ sockets.

Source code in aiperf/common/comms/zmq/zmq_defaults.py
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
class ZMQSocketDefaults:
    """Default values for ZMQ sockets.

    RCVTIMEO/SNDTIMEO are ZMQ socket timeouts in milliseconds
    (300000 ms = 5 minutes, per the inline comments below).
    """

    # Socket Options
    RCVTIMEO = 300000  # 5 minutes (milliseconds)
    SNDTIMEO = 300000  # 5 minutes (milliseconds)
    TCP_KEEPALIVE = 1  # enable TCP keepalive probes
    TCP_KEEPALIVE_IDLE = 60
    TCP_KEEPALIVE_INTVL = 10
    TCP_KEEPALIVE_CNT = 3
    IMMEDIATE = 1  # Don't queue messages
    LINGER = 0  # Don't wait on close

aiperf.common.comms.zmq.zmq_proxy_base

BaseZMQProxy

Bases: AIPerfLoggerMixin, ABC

A Base ZMQ Proxy class.

  • Frontend and backend sockets forward messages bidirectionally
    • Frontend and Backend sockets both BIND
  • Multiple clients CONNECT to frontend_address
  • Multiple services CONNECT to backend_address
  • Control: Optional REP socket for proxy commands (start/stop/pause) - not implemented yet
  • Monitoring: Optional PUB socket that broadcasts copies of all forwarded messages - not implemented yet
  • Proxy runs in separate thread to avoid blocking main event loop
Source code in aiperf/common/comms/zmq/zmq_proxy_base.py
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
class BaseZMQProxy(AIPerfLoggerMixin, ABC):
    """
    A Base ZMQ Proxy class.

    - Frontend and backend sockets forward messages bidirectionally
        - Frontend and Backend sockets both BIND
    - Multiple clients CONNECT to `frontend_address`
    - Multiple services CONNECT to `backend_address`
    - Control: Optional REP socket for proxy commands (start/stop/pause) - not implemented yet
    - Monitoring: Optional PUB socket that broadcasts copies of all forwarded messages - not implemented yet
    - Proxy runs in separate thread to avoid blocking main event loop
    """

    def __init__(
        self,
        frontend_socket_class: type[BaseZMQClient],
        backend_socket_class: type[BaseZMQClient],
        context: zmq.asyncio.Context,
        zmq_proxy_config: BaseZMQProxyConfig,
        socket_ops: dict | None = None,
        proxy_uuid: str | None = None,
    ) -> None:
        """Initialize the ZMQ Proxy. This is a base class for all ZMQ Proxies.

        Args:
            frontend_socket_class (type[BaseZMQClient]): The frontend socket class.
            backend_socket_class (type[BaseZMQClient]): The backend socket class.
            context (zmq.asyncio.Context): The ZMQ context.
            zmq_proxy_config (BaseZMQProxyConfig): The ZMQ proxy configuration.
            socket_ops (dict, optional): Additional socket options to set.
            proxy_uuid (str, optional): An optional UUID for the proxy instance. If not provided,
                a new UUID will be generated. This is useful for tracing and debugging purposes.
        """

        # NOTE(review): proxy_uuid/proxy_id are set before super().__init__() —
        # presumably so the logger mixin can use them; confirm before reordering.
        self.proxy_uuid = proxy_uuid or uuid.uuid4().hex[:8]
        self.proxy_id = f"{self.__class__.__name__.lower()}_{self.proxy_uuid}"
        super().__init__()
        self.context = context
        self.socket_ops = socket_ops

        self.monitor_task: asyncio.Task | None = None
        self.proxy_task: asyncio.Task | None = None
        self.control_client: ProxySocketClient | None = None
        self.capture_client: ProxySocketClient | None = None

        self.frontend_address = zmq_proxy_config.frontend_address
        self.backend_address = zmq_proxy_config.backend_address
        self.control_address = zmq_proxy_config.control_address
        self.capture_address = zmq_proxy_config.capture_address

        self.debug(
            lambda: f"Proxy Initializing - Frontend: {self.frontend_address}, Backend: {self.backend_address}"
        )

        self.backend_socket = backend_socket_class(
            context=self.context,
            address=self.backend_address,
            socket_ops=self.socket_ops,
            proxy_uuid=self.proxy_uuid,  # Pass the proxy UUID for tracing
        )

        self.frontend_socket = frontend_socket_class(
            context=self.context,
            address=self.frontend_address,
            socket_ops=self.socket_ops,
            proxy_uuid=self.proxy_uuid,  # Pass the proxy UUID for tracing
        )

        # Optional control (REP) and capture (PUB) endpoints are created only
        # when their addresses were configured.
        if self.control_address:
            self.debug(lambda: f"Proxy Control - Address: {self.control_address}")
            self.control_client = ProxySocketClient(
                context=self.context,
                socket_type=SocketType.REP,
                address=self.control_address,
                socket_ops=self.socket_ops,
                end_type=ProxyEndType.Control,
                proxy_uuid=self.proxy_uuid,
            )

        if self.capture_address:
            self.debug(lambda: f"Proxy Capture - Address: {self.capture_address}")
            self.capture_client = ProxySocketClient(
                context=self.context,
                socket_type=SocketType.PUB,
                address=self.capture_address,
                socket_ops=self.socket_ops,
                end_type=ProxyEndType.Capture,
                proxy_uuid=self.proxy_uuid,
            )

    @classmethod
    @abstractmethod
    def from_config(
        cls,
        config: BaseZMQProxyConfig | None,
        socket_ops: dict | None = None,
    ) -> "BaseZMQProxy | None":
        """Create a BaseZMQProxy from a BaseZMQProxyConfig, or None if not provided.

        Args:
            config: The proxy configuration, or None to skip proxy creation.
            socket_ops: Additional socket options to set.
        """
        ...

    async def _initialize(self) -> None:
        """Initialize and start the BaseZMQProxy."""
        self.debug("Proxy Initializing Sockets...")
        self.debug(
            lambda: f"Frontend {self.frontend_socket.socket_type.name} socket binding to: {self.frontend_address} (for {self.backend_socket.socket_type.name} clients)"
        )
        self.debug(
            lambda: f"Backend {self.backend_socket.socket_type.name} socket binding to: {self.backend_address} (for {self.frontend_socket.socket_type.name} services)"
        )
        if hasattr(self.backend_socket, "proxy_id"):
            self.debug(
                lambda: f"Backend socket identity: {self.backend_socket.proxy_id}"
            )

        try:
            # Initialize all sockets concurrently; optional clients are only
            # included when they were configured.
            await asyncio.gather(
                self.backend_socket.initialize(),
                self.frontend_socket.initialize(),
                *[
                    client.initialize()
                    for client in [self.control_client, self.capture_client]
                    if client
                ],
            )

            self.debug("Proxy Sockets Initialized Successfully")

            if self.control_client:
                self.debug(lambda: f"Control socket bound to: {self.control_address}")
            if self.capture_client:
                self.debug(lambda: f"Capture socket bound to: {self.capture_address}")

        except Exception as e:
            self.exception(f"Proxy Socket Initialization Failed: {e}")
            raise

    async def stop(self) -> None:
        """Shutdown the BaseZMQProxy."""
        self.debug("Proxy Stopping...")

        try:
            if self.monitor_task is not None:
                self.debug("Cancelling Monitor Task")
                self.monitor_task.cancel()
                # Awaiting a cancelled task raises asyncio.CancelledError, which is
                # a BaseException (Python >= 3.8) and so is NOT caught by the
                # `except Exception` below. It must be suppressed explicitly here,
                # otherwise a clean shutdown would propagate CancelledError.
                with suppress(asyncio.CancelledError, asyncio.TimeoutError):
                    await asyncio.wait_for(
                        self.monitor_task, timeout=TASK_CANCEL_TIMEOUT_SHORT
                    )

        except Exception as e:
            self.exception(f"Proxy Stop Error: {e}")

    async def run(self) -> None:
        """Start the Base ZMQ Proxy.

        This method starts the proxy and waits for it to complete asynchronously.
        The proxy forwards messages between the frontend and backend sockets.

        Raises:
            ProxyError: If the proxy produces an error.
        """
        try:
            await self._initialize()

            self.debug("Starting Proxy...")

            # Only monitor traffic when a capture socket was configured.
            if self.capture_client:
                self.monitor_task = asyncio.create_task(self._monitor_messages())
                self.debug("Proxy Message Monitoring Started")

            # zmq.proxy_steerable blocks until the context is terminated, so it
            # runs in a worker thread to keep the event loop responsive.
            await asyncio.to_thread(
                zmq.proxy_steerable,
                self.frontend_socket.socket,
                self.backend_socket.socket,
                capture=self.capture_client.socket if self.capture_client else None,
                control=self.control_client.socket if self.control_client else None,
            )

        except zmq.ContextTerminated:
            # Expected during shutdown when the ZMQ context is torn down.
            self.debug("Proxy Terminated by Context")
            return

        except Exception as e:
            self.exception(f"Proxy Error: {e}")
            raise ProxyError(f"Proxy failed: {e}") from e

    async def _monitor_messages(self) -> None:
        """Monitor messages flowing through the proxy via the capture socket."""
        if not self.capture_client or not self.capture_address:
            raise ProxyError("Proxy Monitor Not Enabled")

        self.debug(
            lambda: f"Proxy Monitor Starting - Capture Address: {self.capture_address}"
        )

        # SUB socket that connects to the proxy's capture PUB endpoint.
        capture_socket = self.context.socket(SocketType.SUB)
        capture_socket.connect(self.capture_address)
        self.debug(
            lambda: f"Proxy Monitor Connected to Capture Address: {self.capture_address}"
        )
        capture_socket.setsockopt(zmq.SUBSCRIBE, b"")  # Subscribe to all messages
        self.debug("Proxy Monitor Subscribed to all messages")

        try:
            while True:
                recv_msg = await capture_socket.recv_multipart()
                # Bind the message as a default arg so the lazy lambda captures
                # this iteration's value, not the loop variable.
                self.trace(lambda msg=recv_msg: f"Proxy Monitor Received: {msg}")
        except Exception as e:
            self.exception(f"Proxy Monitor Error - {e}")
            raise
        finally:
            capture_socket.close()

__init__(frontend_socket_class, backend_socket_class, context, zmq_proxy_config, socket_ops=None, proxy_uuid=None)

Initialize the ZMQ Proxy. This is a base class for all ZMQ Proxies.

Parameters:

Name Type Description Default
frontend_socket_class type[BaseZMQClient]

The frontend socket class.

required
backend_socket_class type[BaseZMQClient]

The backend socket class.

required
context Context

The ZMQ context.

required
zmq_proxy_config BaseZMQProxyConfig

The ZMQ proxy configuration.

required
socket_ops dict

Additional socket options to set.

None
proxy_uuid str

An optional UUID for the proxy instance. If not provided, a new UUID will be generated. This is useful for tracing and debugging purposes.

None
Source code in aiperf/common/comms/zmq/zmq_proxy_base.py
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
def __init__(
    self,
    frontend_socket_class: type[BaseZMQClient],
    backend_socket_class: type[BaseZMQClient],
    context: zmq.asyncio.Context,
    zmq_proxy_config: BaseZMQProxyConfig,
    socket_ops: dict | None = None,
    proxy_uuid: str | None = None,
) -> None:
    """Initialize the ZMQ Proxy. This is a base class for all ZMQ Proxies.

    Args:
        frontend_socket_class (type[BaseZMQClient]): The frontend socket class.
        backend_socket_class (type[BaseZMQClient]): The backend socket class.
        context (zmq.asyncio.Context): The ZMQ context.
        zmq_proxy_config (BaseZMQProxyConfig): The ZMQ proxy configuration.
        socket_ops (dict, optional): Additional socket options to set.
        proxy_uuid (str, optional): An optional UUID for the proxy instance. If not provided,
            a new UUID will be generated. This is useful for tracing and debugging purposes.
    """

    # NOTE(review): proxy_uuid/proxy_id are set before super().__init__() —
    # presumably so the logger mixin can use them; confirm before reordering.
    self.proxy_uuid = proxy_uuid or uuid.uuid4().hex[:8]
    self.proxy_id = f"{self.__class__.__name__.lower()}_{self.proxy_uuid}"
    super().__init__()
    self.context = context
    self.socket_ops = socket_ops

    self.monitor_task: asyncio.Task | None = None
    self.proxy_task: asyncio.Task | None = None
    self.control_client: ProxySocketClient | None = None
    self.capture_client: ProxySocketClient | None = None

    self.frontend_address = zmq_proxy_config.frontend_address
    self.backend_address = zmq_proxy_config.backend_address
    self.control_address = zmq_proxy_config.control_address
    self.capture_address = zmq_proxy_config.capture_address

    self.debug(
        lambda: f"Proxy Initializing - Frontend: {self.frontend_address}, Backend: {self.backend_address}"
    )

    self.backend_socket = backend_socket_class(
        context=self.context,
        address=self.backend_address,
        socket_ops=self.socket_ops,
        proxy_uuid=self.proxy_uuid,  # Pass the proxy UUID for tracing
    )

    self.frontend_socket = frontend_socket_class(
        context=self.context,
        address=self.frontend_address,
        socket_ops=self.socket_ops,
        proxy_uuid=self.proxy_uuid,  # Pass the proxy UUID for tracing
    )

    # Optional control (REP) and capture (PUB) endpoints are created only
    # when their addresses were configured.
    if self.control_address:
        self.debug(lambda: f"Proxy Control - Address: {self.control_address}")
        self.control_client = ProxySocketClient(
            context=self.context,
            socket_type=SocketType.REP,
            address=self.control_address,
            socket_ops=self.socket_ops,
            end_type=ProxyEndType.Control,
            proxy_uuid=self.proxy_uuid,
        )

    if self.capture_address:
        self.debug(lambda: f"Proxy Capture - Address: {self.capture_address}")
        self.capture_client = ProxySocketClient(
            context=self.context,
            socket_type=SocketType.PUB,
            address=self.capture_address,
            socket_ops=self.socket_ops,
            end_type=ProxyEndType.Capture,
            proxy_uuid=self.proxy_uuid,
        )
from_config(config, socket_ops=None) abstractmethod classmethod

Create a BaseZMQProxy from a BaseZMQProxyConfig, or None if not provided.

Source code in aiperf/common/comms/zmq/zmq_proxy_base.py
149
150
151
152
153
154
155
156
157
@classmethod
@abstractmethod
def from_config(
    cls,
    config: BaseZMQProxyConfig | None,
    socket_ops: dict | None = None,
) -> "BaseZMQProxy | None":
    """Create a BaseZMQProxy from a BaseZMQProxyConfig, or None if not provided.

    Args:
        config: The proxy configuration, or None to skip proxy creation.
        socket_ops: Additional socket options to set.
    """
    ...

run() async

Start the Base ZMQ Proxy.

This method starts the proxy and waits for it to complete asynchronously. The proxy forwards messages between the frontend and backend sockets.

Raises:

Type Description
ProxyError

If the proxy produces an error.

Source code in aiperf/common/comms/zmq/zmq_proxy_base.py
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
async def run(self) -> None:
    """Start the Base ZMQ Proxy.

    This method starts the proxy and waits for it to complete asynchronously.
    The proxy forwards messages between the frontend and backend sockets.

    Raises:
        ProxyError: If the proxy produces an error.
    """
    try:
        await self._initialize()

        self.debug("Starting Proxy...")

        # Only monitor traffic when a capture socket was configured.
        if self.capture_client:
            self.monitor_task = asyncio.create_task(self._monitor_messages())
            self.debug("Proxy Message Monitoring Started")

        # zmq.proxy_steerable blocks until the context is terminated, so it
        # runs in a worker thread to keep the event loop responsive.
        await asyncio.to_thread(
            zmq.proxy_steerable,
            self.frontend_socket.socket,
            self.backend_socket.socket,
            capture=self.capture_client.socket if self.capture_client else None,
            control=self.control_client.socket if self.control_client else None,
        )

    except zmq.ContextTerminated:
        # Expected during shutdown when the ZMQ context is torn down.
        self.debug("Proxy Terminated by Context")
        return

    except Exception as e:
        self.exception(f"Proxy Error: {e}")
        raise ProxyError(f"Proxy failed: {e}") from e

stop() async

Shutdown the BaseZMQProxy.

Source code in aiperf/common/comms/zmq/zmq_proxy_base.py
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
async def stop(self) -> None:
    """Shutdown the BaseZMQProxy.

    Cancels the monitor task (if running) and waits briefly for it to finish.
    Never raises on a normal shutdown.
    """
    self.debug("Proxy Stopping...")

    try:
        if self.monitor_task is not None:
            self.debug("Cancelling Monitor Task")
            self.monitor_task.cancel()
            # Awaiting a cancelled task raises asyncio.CancelledError, which is
            # a BaseException (Python >= 3.8) and so is NOT caught by the
            # `except Exception` below. It must be suppressed explicitly here,
            # otherwise a clean shutdown would propagate CancelledError.
            with suppress(asyncio.CancelledError, asyncio.TimeoutError):
                await asyncio.wait_for(
                    self.monitor_task, timeout=TASK_CANCEL_TIMEOUT_SHORT
                )

    except Exception as e:
        self.exception(f"Proxy Stop Error: {e}")

ProxySocketClient

Bases: BaseZMQClient

A ZMQ Proxy socket client class that extends BaseZMQClient.

This class is used to create proxy sockets for the frontend, backend, capture, and control endpoint types of a ZMQ Proxy.

Source code in aiperf/common/comms/zmq/zmq_proxy_base.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class ProxySocketClient(BaseZMQClient):
    """A ZMQ Proxy socket client class that extends BaseZMQClient.

    This class is used to create proxy sockets for the frontend, backend, capture, and control
    endpoint types of a ZMQ Proxy.
    """

    def __init__(
        self,
        context: zmq.asyncio.Context,
        socket_type: SocketType,
        address: str,
        end_type: ProxyEndType,
        socket_ops: dict | None = None,
        proxy_uuid: str | None = None,
    ) -> None:
        """Initialize the proxy socket client.

        Args:
            context: The ZMQ context.
            socket_type: The ZMQ socket type to create (e.g. REP, PUB, ROUTER).
            address: The address this socket binds to.
            end_type: Which proxy endpoint this socket serves
                (frontend/backend/capture/control).
            socket_ops: Additional socket options to set.
            proxy_uuid: Optional UUID used to build a traceable client id; a
                new one is generated if not provided.
        """
        # Human-readable, unique client id for tracing and debugging.
        self.client_id = f"proxy_{end_type}_{socket_type.name.lower()}_{proxy_uuid or uuid.uuid4().hex[:8]}"
        # Proxy endpoint sockets always BIND; clients/services CONNECT to them.
        super().__init__(
            context,
            socket_type,
            address,
            bind=True,
            socket_ops=socket_ops,
            client_id=self.client_id,
        )
        self.debug(
            lambda: f"ZMQ Proxy {end_type.name} {socket_type.name} - Address: {address}"
        )

ZMQProxyFactory

Bases: FactoryMixin[ZMQProxyType, BaseZMQProxy]

A factory for creating ZMQ proxies. see :class:FactoryMixin for more details.

Source code in aiperf/common/comms/zmq/zmq_proxy_base.py
273
274
class ZMQProxyFactory(FactoryMixin[ZMQProxyType, BaseZMQProxy]):
    """A factory for creating ZMQ proxies, mapping :class:`ZMQProxyType` values to
    registered :class:`BaseZMQProxy` subclasses. see :class:`FactoryMixin` for more details."""

aiperf.common.comms.zmq.zmq_proxy_sockets

ZMQDealerRouterProxy = define_proxy_class(ZMQProxyType.DEALER_ROUTER, create_proxy_socket_class(SocketType.ROUTER, ProxyEndType.Frontend), create_proxy_socket_class(SocketType.DEALER, ProxyEndType.Backend)) module-attribute

A ROUTER socket for the proxy's frontend and a DEALER socket for the proxy's backend.

ASCII Diagram:

    ┌───────────┐     ┌──────────────────────────────────┐     ┌───────────┐
    │  DEALER   │<───>│              PROXY               │<───>│  ROUTER   │
    │ Client 1  │     │  ┌──────────┐      ┌──────────┐  │     │ Service 1 │
    └───────────┘     │  │  ROUTER  │<────>│  DEALER  │  │     └───────────┘
    ┌───────────┐     │  │ Frontend │      │ Backend  │  │     ┌───────────┐
    │  DEALER   │<───>│  └──────────┘      └──────────┘  │<───>│  ROUTER   │
    │ Client N  │     └──────────────────────────────────┘     │ Service N │
    └───────────┘                                              └───────────┘

The ROUTER frontend socket receives messages from DEALER clients and forwards them through the proxy to ROUTER services. The ZMQ proxy handles the message routing automatically.

The DEALER backend socket receives messages from ROUTER services and forwards them through the proxy to DEALER clients. The ZMQ proxy handles the message routing automatically.

CRITICAL: This socket must NOT have an identity when used in a proxy configuration, as it needs to be transparent to preserve routing envelopes for proper response forwarding back to original DEALER clients.

ZMQPushPullProxy = define_proxy_class(ZMQProxyType.PUSH_PULL, create_proxy_socket_class(SocketType.PULL, ProxyEndType.Frontend), create_proxy_socket_class(SocketType.PUSH, ProxyEndType.Backend)) module-attribute

A PULL socket for the proxy's frontend and a PUSH socket for the proxy's backend.

ASCII Diagram: ┌───────────┐ ┌─────────────────────────────────┐ ┌───────────┐ │ PUSH │─────>│ PROXY │─────>│ PULL │ │ Client 1 │ │ ┌──────────┐ ┌──────────┐ │ │ Service 1 │ └───────────┘ │ │ PULL │──────>│ PUSH │ │ └───────────┘ ┌───────────┐ │ │ Frontend │ │ Backend │ │ ┌───────────┐ │ PUSH │─────>│ └──────────┘ └──────────┘ │─────>│ PULL │ │ Client N │ └─────────────────────────────────┘ │ Service N │ └───────────┘ └───────────┘

The PULL frontend socket receives messages from PUSH clients and forwards them through the proxy to PUSH services. The ZMQ proxy handles the message routing automatically.

The PUSH backend socket forwards messages from the proxy to PULL services. The ZMQ proxy handles the message routing automatically.

ZMQXPubXSubProxy = define_proxy_class(ZMQProxyType.XPUB_XSUB, create_proxy_socket_class(SocketType.XSUB, ProxyEndType.Frontend), create_proxy_socket_class(SocketType.XPUB, ProxyEndType.Backend)) module-attribute

An XSUB socket for the proxy's frontend and an XPUB socket for the proxy's backend.

ASCII Diagram: ┌───────────┐ ┌─────────────────────────────────┐ ┌───────────┐ │ PUB │───>│ PROXY │───>│ SUB │ │ Client 1 │ │ ┌──────────┐ ┌──────────┐ │ │ Service 1 │ └───────────┘ │ │ XSUB │──────>│ XPUB │ │ └───────────┘ ┌───────────┐ │ │ Frontend │ │ Backend │ │ ┌───────────┐ │ PUB │───>│ └──────────┘ └──────────┘ │───>│ SUB │ │ Client N │ └─────────────────────────────────┘ │ Service N │ └───────────┘ └───────────┘

The XSUB frontend socket receives messages from PUB clients and forwards them through the proxy to XPUB services. The ZMQ proxy handles the message routing automatically.

The XPUB backend socket forwards messages from the proxy to SUB services. The ZMQ proxy handles the message routing automatically.

create_proxy_socket_class(socket_type, end_type)

Create a proxy socket class using the specified socket type. This is used to reduce the boilerplate code required to create a ZMQ Proxy class.

Source code in aiperf/common/comms/zmq/zmq_proxy_sockets.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def create_proxy_socket_class(
    socket_type: SocketType, end_type: ProxyEndType
) -> type[BaseZMQClient]:
    """Build a proxy socket class specialized for one socket/endpoint type.

    Reduces the boilerplate needed when defining ZMQ Proxy classes: the
    returned :class:`ProxySocketClient` subclass has the given socket type
    and endpoint type baked into its constructor.
    """

    class _SpecializedProxySocket(ProxySocketClient):
        """A ZMQ Proxy socket class with a specific socket type."""

        def __init__(
            self,
            context: zmq.asyncio.Context,
            address: str,
            socket_ops: dict | None = None,
            proxy_uuid: str | None = None,
        ):
            """Initialize the ZMQ Proxy socket class."""

            super().__init__(
                context,
                socket_type,
                address,
                end_type=end_type,
                socket_ops=socket_ops,
                proxy_uuid=proxy_uuid,
            )

    # Give the generated class a descriptive identity for repr/debugging,
    # derived from the socket and endpoint types.
    _SpecializedProxySocket.__name__ = f"ZMQProxy{end_type.name}Socket{socket_type.name}"
    _SpecializedProxySocket.__qualname__ = _SpecializedProxySocket.__name__
    _SpecializedProxySocket.__doc__ = (
        f"A ZMQ Proxy {end_type.name} socket implementation."
    )
    return _SpecializedProxySocket

define_proxy_class(proxy_type, frontend_socket_class, backend_socket_class)

This function reduces the boilerplate code required to create a ZMQ Proxy class. It will generate a ZMQ Proxy class and register it with the ZMQProxyFactory.

Parameters:

Name Type Description Default
proxy_type ZMQProxyType

The type of proxy to generate.

required
frontend_socket_class type[BaseZMQClient]

The class of the frontend socket.

required
backend_socket_class type[BaseZMQClient]

The class of the backend socket.

required
Source code in aiperf/common/comms/zmq/zmq_proxy_sockets.py
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def define_proxy_class(
    proxy_type: ZMQProxyType,
    frontend_socket_class: type[BaseZMQClient],
    backend_socket_class: type[BaseZMQClient],
) -> type[BaseZMQProxy]:
    """This function reduces the boilerplate code required to create a ZMQ Proxy class.
    It will generate a ZMQ Proxy class and register it with the ZMQProxyFactory.

    Args:
        proxy_type: The type of proxy to generate.
        frontend_socket_class: The class of the frontend socket.
        backend_socket_class: The class of the backend socket.

    Returns:
        The generated BaseZMQProxy subclass. Note that calling this function
        has the side effect of registering the class with ZMQProxyFactory.
    """

    class ZMQProxy(BaseZMQProxy):
        """
        A Generated ZMQ Proxy class.

        This class is responsible for creating the ZMQ proxy that forwards messages
        between frontend and backend sockets.
        """

        def __init__(
            self,
            context: zmq.asyncio.Context,
            zmq_proxy_config: BaseZMQProxyConfig,
            socket_ops: dict | None = None,
        ) -> None:
            # The frontend/backend socket classes are captured from the
            # enclosing function's arguments (closure).
            super().__init__(
                frontend_socket_class=frontend_socket_class,
                backend_socket_class=backend_socket_class,
                context=context,
                zmq_proxy_config=zmq_proxy_config,
                socket_ops=socket_ops,
            )

        @classmethod
        def from_config(
            cls,
            config: BaseZMQProxyConfig | None,
            socket_ops: dict | None = None,
        ) -> "ZMQProxy | None":
            # No config means the proxy is disabled; return None instead of raising.
            if config is None:
                return None
            return cls(
                context=zmq.asyncio.Context.instance(),
                zmq_proxy_config=config,
                socket_ops=socket_ops,
            )

    # Dynamically set the class name and qualname based on the proxy type
    ZMQProxy.__name__ = f"ZMQ_{proxy_type.name}_Proxy"
    ZMQProxy.__qualname__ = ZMQProxy.__name__
    ZMQProxy.__doc__ = f"A ZMQ Proxy for {proxy_type.name} communication."
    ZMQProxyFactory.register(proxy_type)(ZMQProxy)
    return ZMQProxy

aiperf.common.config.audio_config

AudioConfig

Bases: BaseConfig

A configuration class for defining audio related settings.

Source code in aiperf/common/config/audio_config.py
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
class AudioConfig(BaseConfig):
    """
    A configuration class for defining audio related settings.

    Field defaults come from ``AudioDefaults``; CLI flag names are declared via
    ``cyclopts.Parameter`` (aliases marked "GenAI-Perf" appear to exist for
    GenAI-Perf flag compatibility — confirm against the CLI docs).
    """

    _GROUP_NAME = "Input Audio"

    batch_size: Annotated[
        int,
        Field(
            ge=0,
            description="The batch size of audio requests AIPerf should send.\n"
            "This is currently supported with the OpenAI `multimodal` endpoint type",
        ),
        cyclopts.Parameter(
            name=(
                "--audio-batch-size",
                "--batch-size-audio",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = AudioDefaults.BATCH_SIZE

    # Nested length settings (mean/stddev); see AudioLengthConfig.
    length: AudioLengthConfig = AudioLengthConfig()

    format: Annotated[
        AudioFormat,
        Field(
            description="The format of the audio files (wav or mp3).",
        ),
        cyclopts.Parameter(
            name=(
                "--audio-format",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = AudioDefaults.FORMAT

    depths: Annotated[
        list[int],
        Field(
            min_length=1,
            description="A list of audio bit depths to randomly select from in bits.",
        ),
        # Accepts either a delimited string or a list; values must be positive.
        BeforeValidator(parse_str_or_list_of_positive_values),
        cyclopts.Parameter(
            name=(
                "--audio-depths",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = AudioDefaults.DEPTHS

    sample_rates: Annotated[
        list[float],
        Field(
            min_length=1,
            description="A list of audio sample rates to randomly select from in kHz.\n"
            "Common sample rates are 16, 44.1, 48, 96, etc.",
        ),
        # Accepts either a delimited string or a list; values must be positive.
        BeforeValidator(parse_str_or_list_of_positive_values),
        cyclopts.Parameter(
            name=(
                "--audio-sample-rates",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = AudioDefaults.SAMPLE_RATES

    num_channels: Annotated[
        int,
        Field(
            ge=1,
            le=2,
            description="The number of audio channels to use for the audio data generation.",
        ),
        cyclopts.Parameter(
            name=(
                "--audio-num-channels",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = AudioDefaults.NUM_CHANNELS

AudioLengthConfig

Bases: BaseConfig

A configuration class for defining audio length related settings.

Source code in aiperf/common/config/audio_config.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
class AudioLengthConfig(BaseConfig):
    """
    A configuration class for defining audio length related settings.

    Lengths are expressed in seconds; defaults come from ``AudioDefaults``.
    """

    _GROUP_NAME = "Input Audio"

    mean: Annotated[
        float,
        Field(
            ge=0,
            description="The mean length of the audio in seconds.",
        ),
        cyclopts.Parameter(
            name=(
                "--audio-length-mean",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = AudioDefaults.LENGTH_MEAN

    stddev: Annotated[
        float,
        Field(
            ge=0,
            description="The standard deviation of the length of the audio in seconds.",
        ),
        cyclopts.Parameter(
            name=(
                "--audio-length-stddev",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = AudioDefaults.LENGTH_STDDEV

aiperf.common.config.base_config

BaseConfig

Bases: AIPerfBaseModel

Base configuration class for all configurations.

Source code in aiperf/common/config/base_config.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
class BaseConfig(AIPerfBaseModel):
    """
    Base configuration class for all configurations.

    Provides YAML serialization that preserves field order and can attach
    Pydantic field descriptions as YAML comments (via ruamel.yaml).
    """

    def serialize_to_yaml(self, verbose: bool = False, indent: int = 4) -> str:
        """
        Serialize a Pydantic model to a YAML string.

        Args:
            verbose: Whether to include verbose comments in the YAML output.
            indent: The per-level indentation to use.
        """
        # Dump model to dict with context (flags propagate recursively)
        context = {
            "verbose": verbose,
        }

        data = self.model_dump(context=context)

        # Attach comments recursively
        commented_data = self._attach_comments(
            data=data,
            model=self,
            context=context,
            indent=indent,
        )

        # Dump to YAML
        yaml = YAML(pure=True)
        yaml.indent(mapping=indent, sequence=indent, offset=indent)

        stream = io.StringIO()
        yaml.dump(commented_data, stream)
        return stream.getvalue()

    @staticmethod
    def _attach_comments(
        data: Any,
        model: AIPerfBaseModel,
        context: dict,
        indent: int,
        indent_level: int = 0,
    ) -> Any:
        """
        Recursively convert dicts to ruamel.yaml CommentedMap and attach comments from
        Pydantic field descriptions, or based on context (e.g., verbose flag).

        Args:
            data: The raw data to convert to a CommentedMap.
            model: The Pydantic model that contains the field descriptions.
            context: The Pydantic serializer context which contains the serializer flags.
            indent: The per-level indentation to use for the comments.
            indent_level: The current level of indentation. The actual indentation is
                `indent * indent_level`.

        Returns:
            The data with comments attached.
        """
        if not isinstance(data, dict):
            # BUGFIX: non-dict data previously fell off the end of the function
            # and was silently replaced by None; pass it through unchanged.
            return data

        # CommentedMap is a dict subclass provided by ruamel.yaml that preserves
        # key order and allows comments to be attached to individual keys.
        commented_map = CommentedMap()

        for field_name, value in data.items():
            field = model.__class__.model_fields.get(field_name)

            if not BaseConfig._should_add_field_to_template(field):
                continue

            if BaseConfig._is_a_nested_config(field, value):
                # Recursively process nested models
                commented_map[field_name] = BaseConfig._attach_comments(
                    value,
                    getattr(model, field_name),
                    context=context,
                    indent=indent,
                    indent_level=indent_level + 1,
                )

                # Blank line before each nested section for readability.
                commented_map.yaml_set_comment_before_after_key(
                    field_name,
                    before="\n",
                    indent=indent * (indent_level + 1),
                )
            else:
                # Attach the (preprocessed) scalar value to the commented map
                commented_map[field_name] = BaseConfig._preprocess_value(value)

            # Attach comment if verbose and description exists
            if context.get("verbose") and field and field.description:
                # Set the comment before the key, with the specified indentation
                commented_map.yaml_set_comment_before_after_key(
                    field_name,
                    before="\n" + field.description,
                    indent=indent * indent_level,
                )

        return commented_map

    @staticmethod
    def _should_add_field_to_template(field: Any) -> bool:
        """Return True unless the field's json_schema_extra disables templating."""
        # If add_to_template is False, we skip adding the field to the template.
        # If add_to_template is True or not present, we include the field.
        if field and field.json_schema_extra:
            return field.json_schema_extra.get(ADD_TO_TEMPLATE, True)
        else:
            return True

    @staticmethod
    def _is_a_nested_config(field: Any, value: Any) -> bool:
        """Return True when *value* is the dump of a nested AIPerfBaseModel field."""
        # BUGFIX: field.annotation may not be a class (e.g. dict[str, Any] or a
        # union), in which case issubclass() raises TypeError — guard it first.
        return bool(
            isinstance(value, dict)
            and field
            and isinstance(field.annotation, type)
            and issubclass(field.annotation, AIPerfBaseModel)
        )

    @staticmethod
    def _preprocess_value(value: Any) -> Any:
        """
        Preprocess the value before serialization.
        """
        # Enums serialize as their lowercase value; Paths as plain strings.
        if isinstance(value, Enum):
            return str(value.value).lower()
        elif isinstance(value, Path):
            return str(value)
        else:
            return value

serialize_to_yaml(verbose=False, indent=4)

Serialize a Pydantic model to a YAML string.

Parameters:

Name Type Description Default
verbose bool

Whether to include verbose comments in the YAML output.

False
indent int

The per-level indentation to use.

4
Source code in aiperf/common/config/base_config.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def serialize_to_yaml(self, verbose: bool = False, indent: int = 4) -> str:
    """
    Serialize this Pydantic model to a YAML string.

    Args:
        verbose: Whether to include verbose comments in the YAML output.
        indent: The per-level indentation to use.
    """
    # Serializer flags travel through the Pydantic context so they
    # propagate into nested models during model_dump().
    dump_context = {"verbose": verbose}

    # Attach field-description comments to the dumped data recursively.
    annotated_data = self._attach_comments(
        data=self.model_dump(context=dump_context),
        model=self,
        context=dump_context,
        indent=indent,
    )

    # Render via ruamel.yaml with uniform indentation at every level.
    emitter = YAML(pure=True)
    emitter.indent(mapping=indent, sequence=indent, offset=indent)

    buffer = io.StringIO()
    emitter.dump(annotated_data, buffer)
    return buffer.getvalue()

aiperf.common.config.config_defaults

aiperf.common.config.config_validators

parse_file(value)

Parses the given string value and returns a Path object if the value represents a valid file or directory. Returns None if the input value is empty. Args: value (str): The string value to parse. Returns: Optional[Path]: A Path object if the value is valid, or None if the value is empty. Raises: ValueError: If the value is not a valid file or directory.

Source code in aiperf/common/config/config_validators.py
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
def parse_file(value: str | None) -> Path | None:
    """
    Parses the given string value and returns a Path object if the value represents
    a valid file or directory. Returns None if the input value is empty.
    Args:
        value (str): The string value to parse.
    Returns:
        Optional[Path]: A Path object if the value is valid, or None if the value is empty.
    Raises:
        ValueError: If the value is not a valid file or directory.
    """

    if not value:
        return None
    elif not isinstance(value, str):
        raise ValueError(f"Expected a string, but got {type(value).__name__}")
    else:
        path = Path(value)
        if path.is_file() or path.is_dir():
            return path
        else:
            raise ValueError(f"'{value}' is not a valid file or directory")

parse_goodput(goodputs)

Parses and validates a dictionary of goodput values, ensuring that all values are non-negative integers or floats, and converts them to floats. Args: goodputs (Dict[str, Any]): A dictionary where keys are target metric names (strings) and values are the corresponding goodput values. Returns: Dict[str, float]: A dictionary with the same keys as the input, but with all values converted to floats. Raises: ValueError: If any value in the input dictionary is not an integer or float, or if any value is negative.

Source code in aiperf/common/config/config_validators.py
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
def parse_goodput(goodputs: dict[str, Any]) -> dict[str, float]:
    """
    Parses and validates a dictionary of goodput values, ensuring that all values
    are non-negative integers or floats, and converts them to floats.
    Args:
        goodputs (Dict[str, Any]): A dictionary where keys are target metric names
            (strings) and values are the corresponding goodput values.
    Returns:
        Dict[str, float]: A dictionary with the same keys as the input, but with
            all values converted to floats.
    Raises:
        ValueError: If any value in the input dictionary is not an integer or float,
            or if any value is negative.
    """

    constraints: dict[str, float] = {}
    for target_metric, target_value in goodputs.items():
        # bool is a subclass of int, but True/False are never meaningful goodput
        # thresholds, so reject them explicitly instead of coercing to 1.0/0.0.
        if isinstance(target_value, bool) or not isinstance(target_value, (int, float)):
            # Include the offending entry, matching the non-negative error below.
            raise ValueError(
                f"User Config: Goodput values must be integers or floats "
                f"({target_metric}: {target_value})"
            )

        if target_value < 0:
            raise ValueError(
                f"User Config: Goodput values must be non-negative ({target_metric}: {target_value})"
            )

        constraints[target_metric] = float(target_value)

    return constraints

parse_service_types(input)

Parses the input to ensure it is a set of service types. Will replace hyphens with underscores for user convenience.

Source code in aiperf/common/config/config_validators.py
73
74
75
76
77
78
79
80
81
82
def parse_service_types(input: Any | None) -> set[ServiceType] | None:
    """Parses the input to ensure it is a set of service types.
    Will replace hyphens with underscores for user convenience."""
    if input is None:
        return None

    service_types: set[ServiceType] = set()
    for raw_name in parse_str_or_csv_list(input):
        # Accept hyphenated spellings by normalizing them to underscores.
        service_types.add(ServiceType(raw_name.replace("-", "_")))
    return service_types

parse_str_or_csv_list(input)

Parses the input to ensure it is either a string or a list. If the input is a string, it splits the string by commas and trims any whitespace around each element, returning the result as a list. If the input is already a list, it will split each item by commas and trim any whitespace around each element, returning the combined result as a list. If the input is neither a string nor a list, a ValueError is raised.

Examples: [1, 2, 3] -> [1, 2, 3]; "1,2,3" -> ["1", "2", "3"]; ["1,2,3", "4,5,6"] -> ["1", "2", "3", "4", "5", "6"]; ["1,2,3", 4, 5] -> ["1", "2", "3", 4, 5]

Source code in aiperf/common/config/config_validators.py
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def parse_str_or_csv_list(input: Any) -> list[Any]:
    """
    Normalize *input* into a flat list. A string is split on commas with
    surrounding whitespace trimmed from each element. A list is flattened
    item by item: string items are split on commas (trimmed), while
    non-string items are kept as-is. Any other input type raises ValueError.

    [1, 2, 3] -> [1, 2, 3]
    "1,2,3" -> ["1", "2", "3"]
    ["1,2,3", "4,5,6"] -> ["1", "2", "3", "4", "5", "6"]
    ["1,2,3", 4, 5] -> ["1", "2", "3", 4, 5]
    """
    if isinstance(input, str):
        return [token.strip() for token in input.split(",")]

    if isinstance(input, list):
        flattened: list[Any] = []
        for element in input:
            if isinstance(element, str):
                flattened.extend(part.strip() for part in element.split(","))
            else:
                flattened.append(element)
        return flattened

    raise ValueError(f"User Config: {input} - must be a string or list")

parse_str_or_dict(input)

Parses the input to ensure it is a dictionary.

  • If the input is a string:
    • If the string starts with a '{', it is parsed as a JSON string.
    • Otherwise, it splits the string by commas and then for each item, it splits the item by colons into key and value, trims any whitespace.
  • If the input is already a dictionary, it is returned as-is.
  • If the input is a list, it is converted to a dictionary by splitting each string by colons into key and value, trims any whitespace.
  • Otherwise, a ValueError is raised.

Parameters:

Name Type Description Default
input Any

The input to be parsed. Expected to be a string, list, or dictionary.

required

Returns: dict[str, Any]: A dictionary derived from the input. Raises: ValueError: If the input is neither a string, list, nor dictionary, or if the parsing fails.

Source code in aiperf/common/config/config_validators.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
def parse_str_or_dict(input: Any | None) -> dict[str, Any] | None:
    """
    Parses the input to ensure it is a dictionary.

    - If the input is a string:
        - If the string starts with a '{', it is parsed as a JSON string.
        - Otherwise, it splits the string by commas and then for each item, it splits the item by colons
        into key and value, trims any whitespace.
    - If the input is already a dictionary, it is returned as-is.
    - If the input is a list, it is converted to a dictionary by splitting each string by colons
    into key and value, trims any whitespace.
    - Otherwise, a ValueError is raised.

    Args:
        input (Any): The input to be parsed. Expected to be a string, list, or dictionary.
    Returns:
        dict[str, Any]: A dictionary derived from the input.
    Raises:
        ValueError: If the input is neither a string, list, nor dictionary, or if the parsing fails.
    """

    if input is None:
        return None

    if isinstance(input, dict):
        return input

    if isinstance(input, list):
        return {
            key.strip(): value.strip()
            for item in input
            for key, value in [item.split(":")]
        }

    if isinstance(input, str):
        if input.startswith("{"):
            try:
                return json.loads(input)
            except json.JSONDecodeError as e:
                raise ValueError(
                    f"User Config: {input} - must be a valid JSON string"
                ) from e
        else:
            return {
                key.strip(): value.strip()
                for item in input.split(",")
                for key, value in [item.split(":")]
            }

    raise ValueError(f"User Config: {input} - must be a valid string, list, or dict")

parse_str_or_list(input)

Parses the input to ensure it is either a string or a list. If the input is a string, it splits the string by commas and trims any whitespace around each element, returning the result as a list. If the input is already a list, any string items are further split by commas with whitespace trimmed, while non-string items are kept as-is, returning the combined result as a list. If the input is neither a string nor a list, a ValueError is raised. Args: input (Any): The input to be parsed. Expected to be a string or a list. Returns: list: A list of strings derived from the input. Raises: ValueError: If the input is neither a string nor a list.

Source code in aiperf/common/config/config_validators.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def parse_str_or_list(input: Any) -> list[Any]:
    """
    Parses the input to ensure it is either a string or a list. If the input is a string,
    it splits the string by commas and trims any whitespace around each element, returning
    the result as a list. If the input is already a list, any string items are further
    split by commas with whitespace trimmed, while non-string items are kept as-is,
    returning the combined result as a list. If the input is neither a string nor a list,
    a ValueError is raised.
    Args:
        input (Any): The input to be parsed. Expected to be a string or a list.
    Returns:
        list: A list of strings derived from the input.
    Raises:
        ValueError: If the input is neither a string nor a list.
    """
    if isinstance(input, str):
        output = [item.strip() for item in input.split(",")]
    elif isinstance(input, list):
        # TODO: When using cyclopts, the values are already lists, so we have to split them by commas.
        output = []
        for item in input:
            if isinstance(item, str):
                output.extend([token.strip() for token in item.split(",")])
            else:
                output.append(item)
    else:
        raise ValueError(f"User Config: {input} - must be a string or list")

    return output

parse_str_or_list_of_positive_values(input)

Parses the input to ensure it is a list of positive integers or floats. This function first converts the input into a list using parse_str_or_list. It then validates that each value in the list is either an integer or a float and that all values are strictly greater than zero. If any value fails this validation, a ValueError is raised. Args: input (Any): The input to be parsed. It can be a string or a list. Returns: List[Any]: A list of positive integers or floats. Raises: ValueError: If any value in the parsed list is not a positive integer or float.

Source code in aiperf/common/config/config_validators.py
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
def parse_str_or_list_of_positive_values(input: Any) -> list[Any]:
    """
    Parses the input to ensure it is a list of positive integers or floats.
    This function first converts the input into a list using `parse_str_or_list`.
    Numeric string tokens (e.g. from CSV input such as "1,2.5") are converted to
    int or float. It then validates that each value is an integer or a float
    strictly greater than zero; otherwise a `ValueError` is raised.
    Args:
        input (Any): The input to be parsed. It can be a string or a list.
    Returns:
        List[Any]: A list of positive integers or floats.
    Raises:
        ValueError: If any value in the parsed list is not a positive integer or float.
    """

    output = parse_str_or_list(input)

    validated: list[Any] = []
    for value in output:
        if isinstance(value, str):
            # BUGFIX: string inputs (documented as supported) produce string
            # tokens, which previously always failed the numeric check below.
            # Coerce them: prefer int, then fall back to float.
            try:
                value = int(value)
            except ValueError:
                try:
                    value = float(value)
                except ValueError as e:
                    raise ValueError(
                        f"User Config: {output} - all values {value} must be a positive integer or float"
                    ) from e

        # bool is a subclass of int but is never a valid positive quantity.
        if isinstance(value, bool) or not isinstance(value, (int, float)) or value <= 0:
            raise ValueError(
                f"User Config: {output} - all values {value} must be a positive integer or float"
            )

        validated.append(value)

    return validated

aiperf.common.config.conversation_config

ConversationConfig

Bases: BaseConfig

A configuration class for defining conversations related settings.

Source code in aiperf/common/config/conversation_config.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
class ConversationConfig(BaseConfig):
    """
    A configuration class for defining conversations related settings.

    Fields are exposed on the CLI under the "Input Conversation" group.
    """

    _GROUP_NAME = "Input Conversation"

    # Total number of unique conversations to generate (must be >= 1).
    # Synthetic mode only; conversations are reused until the run completes.
    num: Annotated[
        int,
        Field(
            ge=1,
            description="The total number of unique conversations to generate.\n"
            "Each conversation represents a single request session between client and server.\n"
            "Supported on synthetic mode only and conversations will be reused until benchmarking is complete.",
        ),
        cyclopts.Parameter(
            name=(
                "--conversation-num",
                "--num-conversations",
                "--num-sessions",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = ConversationDefaults.NUM

    # Nested per-turn settings (turn count distribution and inter-turn delay).
    turn: TurnConfig = TurnConfig()

TurnConfig

Bases: BaseConfig

A configuration class for defining turn related settings in a conversation.

Source code in aiperf/common/config/conversation_config.py
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
class TurnConfig(BaseConfig):
    """
    A configuration class for defining turn related settings in a conversation.
    """

    _GROUP_NAME = "Input Conversation"

    # Mean number of turns per conversation (must be >= 1).
    mean: Annotated[
        int,
        Field(
            ge=1,
            description="The mean number of turns within a conversation.",
        ),
        cyclopts.Parameter(
            name=(
                "--conversation-turn-mean",
                "--session-turns-mean",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = TurnDefaults.MEAN

    # Standard deviation of the per-conversation turn count (must be >= 0).
    stddev: Annotated[
        int,
        Field(
            ge=0,
            description="The standard deviation of the number of turns within a conversation.",
        ),
        cyclopts.Parameter(
            name=(
                "--conversation-turn-stddev",
                "--session-turns-stddev",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = TurnDefaults.STDDEV

    # Nested settings for the delay applied between turns.
    delay: TurnDelayConfig = TurnDelayConfig()

TurnDelayConfig

Bases: BaseConfig

A configuration class for defining turn delay related settings.

Source code in aiperf/common/config/conversation_config.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
class TurnDelayConfig(BaseConfig):
    """
    A configuration class for defining turn delay related settings.
    """

    _GROUP_NAME = "Input Conversation"

    # Mean delay between turns in milliseconds (must be >= 0).
    mean: Annotated[
        float,
        Field(
            ge=0,
            description="The mean delay between turns within a conversation in milliseconds.",
        ),
        cyclopts.Parameter(
            name=(
                "--conversation-turn-delay-mean",
                "--session-turn-delay-mean",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = TurnDelayDefaults.MEAN

    # Standard deviation of the inter-turn delay in milliseconds (must be >= 0).
    stddev: Annotated[
        float,
        Field(
            ge=0,
            description="The standard deviation of the delay between turns \n"
            "within a conversation in milliseconds.",
        ),
        cyclopts.Parameter(
            name=(
                "--conversation-turn-delay-stddev",
                "--session-turn-delay-stddev",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = TurnDelayDefaults.STDDEV

    # Multiplier applied to multi-turn delays (must be >= 0).
    ratio: Annotated[
        float,
        Field(
            ge=0,
            description="A ratio to scale multi-turn delays.",
        ),
        cyclopts.Parameter(
            name=(
                "--conversation-turn-delay-ratio",
                "--session-delay-ratio",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = TurnDelayDefaults.RATIO

aiperf.common.config.endpoint_config

EndPointConfig

Bases: BaseConfig

A configuration class for defining endpoint related settings.

Source code in aiperf/common/config/endpoint_config.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
class EndPointConfig(BaseConfig):
    """
    A configuration class for defining endpoint related settings.

    Fields are exposed on the CLI under the "Endpoint" group; defaults come
    from ``EndPointDefaults``.
    """

    _GROUP_NAME = "Endpoint"

    # How a model is chosen per prompt when multiple models are configured.
    model_selection_strategy: Annotated[
        ModelSelectionStrategy,
        Field(
            description="When multiple models are specified, this is how a specific model should be assigned to a prompt.\n"
            "round_robin: nth prompt in the list gets assigned to n-mod len(models).\n"
            "random: assignment is uniformly random",
        ),
        cyclopts.Parameter(
            name=(
                "--model-selection-strategy",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = EndPointDefaults.MODEL_SELECTION_STRATEGY

    # Optional endpoint path overriding the OpenAI defaults.
    custom_endpoint: Annotated[
        str | None,
        Field(
            description="Set a custom endpoint that differs from the OpenAI defaults.",
        ),
        cyclopts.Parameter(
            name=(
                "--custom-endpoint",
                "--endpoint",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = EndPointDefaults.CUSTOM_ENDPOINT

    # Kind of API endpoint requests are sent to.
    type: Annotated[
        EndpointType,
        Field(
            description="The type to send requests to on the server.",
        ),
        cyclopts.Parameter(
            name=(
                "--endpoint-type",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = EndPointDefaults.TYPE

    # Whether to use the streaming API.
    streaming: Annotated[
        bool,
        Field(
            description="An option to enable the use of the streaming API.",
        ),
        cyclopts.Parameter(
            name=(
                "--streaming",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = EndPointDefaults.STREAMING

    # Triton server metrics URLs used for telemetry reporting.
    server_metrics_urls: Annotated[
        list[str],
        Field(
            description="The list of Triton server metrics URLs.\n"
            "These are used for Telemetry metric reporting with Triton.",
        ),
        BeforeValidator(parse_str_or_list),
        cyclopts.Parameter(
            name=(
                "--server-metrics-urls",  # GenAI-Perf
                "--server-metrics-url",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = EndPointDefaults.SERVER_METRICS_URLS

    # Base URL of the endpoint being benchmarked.
    url: Annotated[
        str,
        Field(
            description="URL of the endpoint to target for benchmarking.",
        ),
        cyclopts.Parameter(
            name=(
                "--url",  # GenAI-Perf
                "-u",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = EndPointDefaults.URL

    # Fully-qualified gRPC method, required for the dynamic gRPC service kind.
    grpc_method: Annotated[
        str,
        Field(
            description="A fully-qualified gRPC method name in "
            "'<package>.<service>/<method>' format.\n"
            "The option is only supported by dynamic gRPC service kind and is\n"
            "required to identify the RPC to use when sending requests to the server.",
        ),
        cyclopts.Parameter(
            name=(
                "--grpc-method",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = EndPointDefaults.GRPC_METHOD

    # NEW AIPerf Option
    # Per-request timeout in seconds.
    timeout_seconds: Annotated[
        float,
        Field(
            description="The timeout in floating points seconds for each request to the endpoint.",
        ),
        cyclopts.Parameter(
            # FIX: ("--x") is just a parenthesized string; use a 1-tuple for
            # consistency with every other option in this class.
            name=("--request-timeout-seconds",),
            group=_GROUP_NAME,
        ),
    ] = EndPointDefaults.TIMEOUT

    # NEW AIPerf Option
    # Optional bearer token sent with every request.
    api_key: Annotated[
        str | None,
        Field(
            # FIX: the previous literals concatenated to "...every request asas
            # a header..." in the user-facing help text.
            description="The API key to use for the endpoint. If provided, it will be sent with every request "
            "as a header: `Authorization: Bearer <api_key>`.",
        ),
        cyclopts.Parameter(
            name=("--api-key",),
            group=_GROUP_NAME,
        ),
    ] = EndPointDefaults.API_KEY

aiperf.common.config.image_config

ImageConfig

Bases: BaseConfig

A configuration class for defining image related settings.

Source code in aiperf/common/config/image_config.py
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
class ImageConfig(BaseConfig):
    """
    A configuration class for defining image related settings.
    """

    _GROUP_NAME = "Input Image"

    # Nested synthetic image dimension settings.
    width: ImageWidthConfig = ImageWidthConfig()
    height: ImageHeightConfig = ImageHeightConfig()
    # Number of images per request (must be >= 0); currently supported with
    # the image retrieval endpoint type.
    batch_size: Annotated[
        int,
        Field(
            ge=0,
            description="The image batch size of the requests AIPerf should send.\n"
            "This is currently supported with the image retrieval endpoint type.",
        ),
        cyclopts.Parameter(
            name=(
                "--image-batch-size",
                "--batch-size-image",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = ImageDefaults.BATCH_SIZE

    # Compression format used for generated images.
    format: Annotated[
        ImageFormat,
        Field(
            description="The compression format of the images.",
        ),
        cyclopts.Parameter(
            name=(
                "--image-format",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = ImageDefaults.FORMAT

ImageHeightConfig

Bases: BaseConfig

A configuration class for defining image height related settings.

Source code in aiperf/common/config/image_config.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class ImageHeightConfig(BaseConfig):
    """
    A configuration class for defining image height related settings.
    """

    _GROUP_NAME = "Input Image"

    # Mean height (pixels) of generated synthetic images (must be >= 0).
    mean: Annotated[
        float,
        Field(
            ge=0,
            description="The mean height of images when generating synthetic image data.",
        ),
        cyclopts.Parameter(
            name=(
                "--image-height-mean",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = ImageDefaults.HEIGHT_MEAN

    # Standard deviation of generated image heights (must be >= 0).
    stddev: Annotated[
        float,
        Field(
            ge=0,
            description="The standard deviation of height of images when generating synthetic image data.",
        ),
        cyclopts.Parameter(
            name=(
                "--image-height-stddev",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = ImageDefaults.HEIGHT_STDDEV

ImageWidthConfig

Bases: BaseConfig

A configuration class for defining image width related settings.

Source code in aiperf/common/config/image_config.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
class ImageWidthConfig(BaseConfig):
    """
    A configuration class for defining image width related settings.
    """

    _GROUP_NAME = "Input Image"

    # Mean width (pixels) of generated synthetic images (must be >= 0).
    mean: Annotated[
        float,
        Field(
            ge=0,
            description="The mean width of images when generating synthetic image data.",
        ),
        cyclopts.Parameter(
            name=(
                "--image-width-mean",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = ImageDefaults.WIDTH_MEAN

    # Standard deviation of generated image widths (must be >= 0).
    stddev: Annotated[
        float,
        Field(
            ge=0,
            description="The standard deviation of width of images when generating synthetic image data.",
        ),
        cyclopts.Parameter(
            name=(
                "--image-width-stddev",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = ImageDefaults.WIDTH_STDDEV

aiperf.common.config.input_config

InputConfig

Bases: BaseConfig

A configuration class for defining input related settings.

Source code in aiperf/common/config/input_config.py
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
class InputConfig(BaseConfig):
    """
    A configuration class for defining input related settings.
    """

    _GROUP_NAME = "Input"

    @model_validator(mode="after")
    def validate_fixed_schedule(self) -> Self:
        """Validate the fixed schedule configuration."""
        if self.fixed_schedule and self.file is None:
            raise ValueError("Fixed schedule requires a file to be provided")
        if self.file is not None:
            self.fixed_schedule = True
            logger.debug("Fixed schedule is enabled because file is provided")
        return self

    extra: Annotated[
        dict[str, Any] | None,
        Field(
            description="Provide additional inputs to include with every request.\n"
            "Inputs should be in an 'input_name:value' format.",
        ),
        cyclopts.Parameter(
            name=(
                "--extra-inputs",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
        BeforeValidator(parse_str_or_dict),
    ] = InputDefaults.EXTRA

    goodput: Annotated[
        dict[str, Any],
        Field(
            description="An option to provide constraints in order to compute goodput.\n"
            "Specify goodput constraints as 'key:value' pairs,\n"
            "where the key is a valid metric name, and the value is a number representing\n"
            "either milliseconds or a throughput value per second.\n"
            "For example: request_latency:300,output_token_throughput_per_user:600",
        ),
        cyclopts.Parameter(
            name=(
                "--goodput",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
        BeforeValidator(parse_goodput),
    ] = InputDefaults.GOODPUT

    headers: Annotated[
        dict[str, str] | None,
        Field(
            description="Adds a custom header to the requests.\n"
            "Headers must be specified as 'Header:Value' pairs.",
        ),
        BeforeValidator(parse_str_or_dict),
        cyclopts.Parameter(
            name=(
                "--header",  # GenAI-Perf
                "-H",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = InputDefaults.HEADERS

    file: Annotated[
        Any,
        Field(
            description="The file or directory path that contains the dataset to use for profiling.\n"
            "This parameter is used in conjunction with the `custom_dataset_type` parameter\n"
            "to support different types of user provided datasets.",
        ),
        BeforeValidator(parse_file),
        cyclopts.Parameter(
            name=(
                "--input-file",  # GenAI-Perf,
            ),
            group=_GROUP_NAME,
        ),
    ] = InputDefaults.FILE

    fixed_schedule: Annotated[
        bool,
        Field(
            description="Specifies to run a fixed schedule of requests. This is normally inferred from the --input-file parameter, but can be set manually here."
        ),
        cyclopts.Parameter(
            name=(
                "--fixed-schedule",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = InputDefaults.FIXED_SCHEDULE

    # NEW AIPerf Option
    custom_dataset_type: Annotated[
        CustomDatasetType,
        Field(
            description="The type of custom dataset to use.\n"
            "This parameter is used in conjunction with the --file parameter.",
        ),
        cyclopts.Parameter(
            name=("--custom-dataset-type"),
            group=_GROUP_NAME,
        ),
    ] = InputDefaults.CUSTOM_DATASET_TYPE

    random_seed: Annotated[
        int | None,
        Field(
            default=None,
            description="The seed used to generate random values.\n"
            "Set to some value to make the synthetic data generation deterministic.\n"
            "It will use system default if not provided.",
        ),
        cyclopts.Parameter(
            name=(
                "--random-seed",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = InputDefaults.RANDOM_SEED

    audio: AudioConfig = AudioConfig()
    image: ImageConfig = ImageConfig()
    prompt: PromptConfig = PromptConfig()
    conversation: ConversationConfig = ConversationConfig()

validate_fixed_schedule()

Validate the fixed schedule configuration.

Source code in aiperf/common/config/input_config.py
34
35
36
37
38
39
40
41
42
@model_validator(mode="after")
def validate_fixed_schedule(self) -> Self:
    """Validate the fixed schedule configuration."""
    if self.fixed_schedule and self.file is None:
        raise ValueError("Fixed schedule requires a file to be provided")
    if self.file is not None:
        self.fixed_schedule = True
        logger.debug("Fixed schedule is enabled because file is provided")
    return self

aiperf.common.config.loader

load_service_config()

Load the service configuration.

Source code in aiperf/common/config/loader.py
 7
 8
 9
10
def load_service_config() -> ServiceConfig:
    """Load the service configuration."""
    # TODO: implement
    return ServiceConfig()

load_user_config()

Load the user configuration.

Source code in aiperf/common/config/loader.py
13
14
15
16
def load_user_config() -> UserConfig:
    """Load the user configuration."""
    # TODO: implement
    raise NotImplementedError("User configuration is not implemented")

aiperf.common.config.loadgen_config

LoadGeneratorConfig

Bases: BaseConfig

A configuration class for defining top-level load generator settings.

Source code in aiperf/common/config/loadgen_config.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
class LoadGeneratorConfig(BaseConfig):
    """
    A configuration class for defining top-level load generator settings.
    """

    _GROUP_NAME = "Load Generator"

    # TODO: Potentially add a validator to ensure that the concurrency is not greater than the request count
    concurrency: Annotated[
        int,
        Field(
            ge=1,
            description="The concurrency value to benchmark.",
        ),
        cyclopts.Parameter(
            name=(
                "--concurrency",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = LoadGeneratorDefaults.CONCURRENCY

    request_rate: Annotated[
        float | None,
        Field(
            gt=0,
            description="Sets the request rate for the load generated by AIPerf. Unit: requests/second",
        ),
        cyclopts.Parameter(
            name=(
                "--request-rate",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = LoadGeneratorDefaults.REQUEST_RATE

    # NEW AIPerf Option
    request_rate_mode: Annotated[
        RequestRateMode,
        Field(
            description="Sets the request rate mode for the load generated by AIPerf. Valid values: constant, poisson.\n"
            "constant: Generate requests at a fixed rate.\n"
            "poisson: Generate requests using a poisson distribution.\n"
            f"The default is {LoadGeneratorDefaults.REQUEST_RATE_MODE}.",
        ),
        cyclopts.Parameter(
            name=("--request-rate-mode"),
            group=_GROUP_NAME,
        ),
    ] = LoadGeneratorDefaults.REQUEST_RATE_MODE

    request_count: Annotated[
        int,
        Field(
            ge=1,
            description="The number of requests to use for measurement.",
        ),
        cyclopts.Parameter(
            name=(
                "--request-count",  # GenAI-Perf
                "--num-requests",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = LoadGeneratorDefaults.REQUEST_COUNT

    warmup_request_count: Annotated[
        int,
        Field(
            ge=0,
            description="The number of warmup requests to send before benchmarking.",
        ),
        cyclopts.Parameter(
            name=(
                "--warmup-request-count",  # GenAI-Perf
                "--num-warmup-requests",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = LoadGeneratorDefaults.WARMUP_REQUEST_COUNT

aiperf.common.config.measurement_config

MeasurementConfig

Bases: BaseConfig

A configuration class for defining top-level measurement settings.

Source code in aiperf/common/config/measurement_config.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
class MeasurementConfig(BaseConfig):
    """
    A configuration class for defining top-level measurement settings.
    """

    _GROUP_NAME = "Measurement"

    # TODO: Not implemented yet
    measurement_interval: Annotated[
        float,
        Field(
            ge=1,
            le=1_000_000,
            description="The time interval used for each measurement in milliseconds. "
            "AIPerf will sample a time interval specified and take "
            "measurement over the requests completed within that time interval. "
            "When using the default stability percentage, AIPerf will benchmark  "
            "for 3*(measurement_interval) milliseconds.",
        ),
        cyclopts.Parameter(
            name=(
                "--measurement-interval-ms",
                "--measurement-interval",  # GenAI-Perf
                "-p",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = MeasurementDefaults.MEASUREMENT_INTERVAL

    # TODO: Not implemented yet
    stability_percentage: Annotated[
        float,
        Field(
            gt=0.0,
            lt=1.0,
            description="The allowed variation in latency measurements when determining if a result is stable.\n"
            "The measurement is considered as stable if the ratio of max / min\n"
            "from the recent 3 measurements is within (stability percentage)\n"
            "in terms of both infer per second and latency.",
        ),
        cyclopts.Parameter(
            name=(
                "--stability-percentage",  # GenAI-Perf
                "-s",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = MeasurementDefaults.STABILITY_PERCENTAGE

aiperf.common.config.output_config

OutputConfig

Bases: BaseConfig

A configuration class for defining output related settings.

Source code in aiperf/common/config/output_config.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class OutputConfig(BaseConfig):
    """
    A configuration class for defining output related settings.
    """

    _GROUP_NAME = "Output"

    artifact_directory: Annotated[
        Path,
        Field(
            description="The directory to store all the (output) artifacts generated by AIPerf.",
        ),
        cyclopts.Parameter(
            name=(
                "--output-artifact-dir",
                "--artifact-dir",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = OutputDefaults.ARTIFACT_DIRECTORY

aiperf.common.config.prompt_config

InputTokensConfig

Bases: BaseConfig

A configuration class for defining input token related settings.

Source code in aiperf/common/config/prompt_config.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
class InputTokensConfig(BaseConfig):
    """
    A configuration class for defining input token related settings.
    """

    _GROUP_NAME = "Input Sequence Length"

    mean: Annotated[
        int,
        Field(
            ge=0,
            description="The mean of number of tokens in the generated prompts when using synthetic data.",
        ),
        cyclopts.Parameter(
            name=(
                "--prompt-input-tokens-mean",
                "--synthetic-input-tokens-mean",  # GenAI-Perf
                "--isl",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = InputTokensDefaults.MEAN

    stddev: Annotated[
        float,
        Field(
            ge=0,
            description="The standard deviation of number of tokens in the generated prompts when using synthetic data.",
        ),
        cyclopts.Parameter(
            name=(
                "--prompt-input-tokens-stddev",
                "--synthetic-input-tokens-stddev",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = InputTokensDefaults.STDDEV

    # NEW AIPerf Option
    block_size: Annotated[
        int,
        Field(
            default=512,
            description="The block size of the prompt.",
        ),
        cyclopts.Parameter(
            name=(
                "--prompt-input-tokens-block-size",
                "--synthetic-input-tokens-block-size",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = InputTokensDefaults.BLOCK_SIZE

OutputTokensConfig

Bases: BaseConfig

A configuration class for defining output token related settings.

Source code in aiperf/common/config/prompt_config.py
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
class OutputTokensConfig(BaseConfig):
    """
    A configuration class for defining output token related settings.
    """

    _GROUP_NAME = "Output Sequence Length"

    mean: Annotated[
        int,
        Field(
            ge=0,
            description="The mean number of tokens in each output.",
        ),
        cyclopts.Parameter(
            name=(
                "--prompt-output-tokens-mean",
                "--output-tokens-mean",  # GenAI-Perf
                "--osl",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = OutputTokensDefaults.MEAN

    deterministic: Annotated[
        bool,
        Field(
            description=(
                "This can be set to improve the precision of the mean by setting the\n"
                "minimum number of tokens equal to the requested number of tokens.\n"
                "This is currently supported with Triton."
            ),
        ),
        cyclopts.Parameter(
            name=(
                "--prompt-output-tokens-deterministic",
                "--output-tokens-mean-deterministic",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = OutputTokensDefaults.DETERMINISTIC

    stddev: Annotated[
        float,
        Field(
            ge=0,
            description="The standard deviation of the number of tokens in each output.",
        ),
        cyclopts.Parameter(
            name=(
                "--prompt-output-tokens-stddev",
                "--output-tokens-stddev",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = OutputTokensDefaults.STDDEV

PrefixPromptConfig

Bases: BaseConfig

A configuration class for defining prefix prompt related settings.

Source code in aiperf/common/config/prompt_config.py
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
class PrefixPromptConfig(BaseConfig):
    """
    A configuration class for defining prefix prompt related settings.
    """

    _GROUP_NAME = "Prefix Prompt"

    pool_size: Annotated[
        int,
        Field(
            ge=0,
            description=(
                "The total size of the prefix prompt pool to select prefixes from.\n"
                "If this value is not zero, these are prompts that are prepended to input prompts.\n"
                "This is useful for benchmarking models that use a K-V cache."
            ),
        ),
        cyclopts.Parameter(
            name=(
                "--prompt-prefix-pool-size",
                "--prefix-prompt-pool-size",
                "--num-prefix-prompts",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = PrefixPromptDefaults.POOL_SIZE

    length: Annotated[
        int,
        Field(
            ge=0,
            description=(
                "The number of tokens in each prefix prompt.\n"
                'This is only used if "num" is greater than zero.\n'
                "Note that due to the prefix and user prompts being concatenated,\n"
                "the number of tokens in the final prompt may be off by one."
            ),
        ),
        cyclopts.Parameter(
            name=(
                "--prompt-prefix-length",
                "--prefix-prompt-length",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = PrefixPromptDefaults.LENGTH

PromptConfig

Bases: BaseConfig

A configuration class for defining prompt related settings.

Source code in aiperf/common/config/prompt_config.py
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
class PromptConfig(BaseConfig):
    """
    A configuration class for defining prompt related settings.
    """

    _GROUP_NAME = "Prompt"

    batch_size: Annotated[
        int,
        Field(
            description="The batch size of text requests AIPerf should send.\n"
            "This is currently supported with the embeddings and rankings endpoint types",
        ),
        cyclopts.Parameter(
            name=(
                "--prompt-batch-size",
                "--batch-size-text",  # GenAI-Perf
                "--batch-size",  # GenAI-Perf
                "-b",  # GenAI-Perf
            ),
            group=_GROUP_NAME,
        ),
    ] = PromptDefaults.BATCH_SIZE

    input_tokens: InputTokensConfig = InputTokensConfig()
    output_tokens: OutputTokensConfig = OutputTokensConfig()
    prefix_prompt: PrefixPromptConfig = PrefixPromptConfig()

aiperf.common.config.service_config

ServiceConfig

Bases: BaseSettings

Base configuration for all services. It will be provided to all services during their `__init__` function.

Source code in aiperf/common/config/service_config.py
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
class ServiceConfig(BaseSettings):
    """Base configuration for all services. It will be provided to all services during their __init__ function."""

    model_config = SettingsConfigDict(
        env_prefix="AIPERF_",
        env_file=".env",
        env_file_encoding="utf-8",
        extra="allow",
    )

    _GROUP_NAME = "Service"

    @model_validator(mode="after")
    def validate_log_level_from_verbose_flags(self) -> Self:
        """Set log level based on verbose flags."""
        if self.extra_verbose:
            self.log_level = AIPerfLogLevel.TRACE
        elif self.verbose:
            self.log_level = AIPerfLogLevel.DEBUG
        return self

    @model_validator(mode="after")
    def validate_comm_config(self) -> Self:
        """Initialize the comm_config if it is not provided, based on the comm_backend."""
        if self.comm_config is None:
            if self.comm_backend == CommunicationBackend.ZMQ_IPC:
                self.comm_config = ZMQIPCConfig()
            elif self.comm_backend == CommunicationBackend.ZMQ_TCP:
                self.comm_config = ZMQTCPConfig()
            else:
                raise ValueError(f"Invalid communication backend: {self.comm_backend}")
        return self

    service_run_type: Annotated[
        ServiceRunType,
        Field(
            description="Type of service run (process, k8s)",
        ),
        cyclopts.Parameter(
            name=("--service-run-type", "--run-type"),
            group=_GROUP_NAME,
        ),
    ] = ServiceDefaults.SERVICE_RUN_TYPE

    comm_backend: Annotated[
        CommunicationBackend,
        Field(
            description="Communication backend to use",
        ),
        cyclopts.Parameter(
            name=("--comm-backend"),
            group=_GROUP_NAME,
        ),
    ] = ServiceDefaults.COMM_BACKEND

    comm_config: Annotated[
        BaseZMQCommunicationConfig | None,
        Field(
            description="Communication configuration",
        ),
        # TODO: Figure out if we need to be able to set this from the command line.
        cyclopts.Parameter(
            name=("--comm-config"),
            group="Not Supported via CLI",
        ),
    ] = ServiceDefaults.COMM_CONFIG

    heartbeat_timeout: Annotated[
        float,
        Field(
            description="Time in seconds after which a service is considered dead if no "
            "heartbeat received",
        ),
        cyclopts.Parameter(
            name=("--heartbeat-timeout"),
            group=_GROUP_NAME,
        ),
    ] = ServiceDefaults.HEARTBEAT_TIMEOUT

    registration_timeout: Annotated[
        float,
        Field(
            description="Time in seconds to wait for all required services to register",
        ),
        cyclopts.Parameter(
            name=("--registration-timeout"),
            group=_GROUP_NAME,
        ),
    ] = ServiceDefaults.REGISTRATION_TIMEOUT

    command_timeout: Annotated[
        float,
        Field(
            description="Default timeout for command responses",
        ),
        cyclopts.Parameter(
            name=("--command-timeout", "--command-timeout-seconds"),
            group=_GROUP_NAME,
        ),
    ] = ServiceDefaults.COMMAND_TIMEOUT

    heartbeat_interval_seconds: Annotated[
        float,
        Field(
            description="Interval in seconds between heartbeat messages",
        ),
        cyclopts.Parameter(
            name=("--heartbeat-interval-seconds", "--heartbeat-interval"),
            group=_GROUP_NAME,
        ),
    ] = ServiceDefaults.HEARTBEAT_INTERVAL_SECONDS

    workers: Annotated[
        WorkersConfig,
        Field(
            description="Worker configuration",
        ),
    ] = WorkersConfig()

    log_level: Annotated[
        AIPerfLogLevel,
        Field(
            description="Logging level",
        ),
        cyclopts.Parameter(
            name=("--log-level"),
            group=_GROUP_NAME,
        ),
    ] = ServiceDefaults.LOG_LEVEL

    verbose: Annotated[
        bool,
        Field(
            description="Equivalent to --log-level DEBUG. Enables more verbose logging output, but lacks some raw message logging.",
            json_schema_extra={ADD_TO_TEMPLATE: False},
        ),
        cyclopts.Parameter(
            name=("--verbose", "-v"),
            group=_GROUP_NAME,
        ),
    ] = ServiceDefaults.VERBOSE

    extra_verbose: Annotated[
        bool,
        Field(
            description="Equivalent to --log-level TRACE. Enables the most verbose logging output possible.",
            json_schema_extra={ADD_TO_TEMPLATE: False},
        ),
        cyclopts.Parameter(
            name=("--extra-verbose", "-vv"),
            group=_GROUP_NAME,
        ),
    ] = ServiceDefaults.EXTRA_VERBOSE

    disable_ui: Annotated[
        bool,
        Field(
            description="Disable the UI (prints progress to the console as log messages). This is equivalent to --ui-type none.",
        ),
        cyclopts.Parameter(
            name=("--disable-ui"),
            group=_GROUP_NAME,
        ),
    ] = ServiceDefaults.DISABLE_UI

    enable_uvloop: Annotated[
        bool,
        Field(
            description="Enable the use of uvloop instead of the default asyncio event loop",
        ),
        cyclopts.Parameter(
            name=("--enable-uvloop"),
            group=_GROUP_NAME,
        ),
    ] = ServiceDefaults.ENABLE_UVLOOP

    # TODO: Potentially auto-scale this in the future.
    result_parser_service_count: Annotated[
        int,
        Field(
            description="Number of services to spawn for parsing inference results. The higher the request rate, the more services should be spawned.",
        ),
        cyclopts.Parameter(
            name=("--result-parser-service-count"),
            group=_GROUP_NAME,
        ),
    ] = ServiceDefaults.RESULT_PARSER_SERVICE_COUNT

    enable_yappi: Annotated[
        bool,
        Field(
            description="[Developer use only] Enable yappi profiling (Yet Another Python Profiler) to profile AIPerf's internal python code. "
            "This can be used in the development of AIPerf in order to find performance bottlenecks across the various services. "
            "The output '*.prof' files can be viewed with snakeviz. Requires yappi and snakeviz to be installed. "
            "Run 'pip install yappi snakeviz' to install them.",
        ),
        cyclopts.Parameter(
            name=("--enable-yappi-profiling"),
            group=_GROUP_NAME,
        ),
    ] = ServiceDefaults.ENABLE_YAPPI

    debug_services: Annotated[
        set[ServiceType] | None,
        Field(
            description="List of services to enable debug logging for. Can be a comma-separated list, a single service type, "
            "or the cli flag can be used multiple times.",
        ),
        cyclopts.Parameter(
            name=("--debug-service", "--debug-services"),
            group=_GROUP_NAME,
        ),
        BeforeValidator(parse_service_types),
    ] = ServiceDefaults.DEBUG_SERVICES

    trace_services: Annotated[
        set[ServiceType] | None,
        Field(
            description="List of services to enable trace logging for. Can be a comma-separated list, a single service type, "
            "or the cli flag can be used multiple times.",
        ),
        cyclopts.Parameter(
            name=("--trace-service", "--trace-services"),
            group=_GROUP_NAME,
        ),
        BeforeValidator(parse_service_types),
    ] = ServiceDefaults.TRACE_SERVICES

validate_comm_config()

Initialize the comm_config if it is not provided, based on the comm_backend.

Source code in aiperf/common/config/service_config.py
48
49
50
51
52
53
54
55
56
57
58
@model_validator(mode="after")
def validate_comm_config(self) -> Self:
    """Initialize the comm_config if it is not provided, based on the comm_backend."""
    if self.comm_config is None:
        if self.comm_backend == CommunicationBackend.ZMQ_IPC:
            self.comm_config = ZMQIPCConfig()
        elif self.comm_backend == CommunicationBackend.ZMQ_TCP:
            self.comm_config = ZMQTCPConfig()
        else:
            raise ValueError(f"Invalid communication backend: {self.comm_backend}")
    return self

validate_log_level_from_verbose_flags()

Set log level based on verbose flags.

Source code in aiperf/common/config/service_config.py
39
40
41
42
43
44
45
46
@model_validator(mode="after")
def validate_log_level_from_verbose_flags(self) -> Self:
    """Set log level based on verbose flags."""
    if self.extra_verbose:
        self.log_level = AIPerfLogLevel.TRACE
    elif self.verbose:
        self.log_level = AIPerfLogLevel.DEBUG
    return self

aiperf.common.config.sweep_config

SweepConfig

Bases: BaseConfig

A sweep of parameters.

Source code in aiperf/common/config/sweep_config.py
 99
100
class SweepConfig(BaseConfig):
    """A sweep of parameters."""

SweepParam

Bases: BaseConfig

A parameter to be swept.

Source code in aiperf/common/config/sweep_config.py
8
9
class SweepParam(BaseConfig):
    """A parameter to be swept."""

aiperf.common.config.tokenizer_config

TokenizerConfig

Bases: BaseConfig

A configuration class for defining tokenizer related settings.

Source code in aiperf/common/config/tokenizer_config.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
class TokenizerConfig(BaseConfig):
    """
    A configuration class for defining tokenizer related settings.
    """

    _GROUP_NAME = "Tokenizer"

    name: Annotated[
        str | None,
        Field(
            description=(
                "The HuggingFace tokenizer to use to interpret token metrics "
                "from prompts and responses.\nThe value can be the "
                "name of a tokenizer or the filepath of the tokenizer.\n"
                "The default value is the model name."
            ),
        ),
        cyclopts.Parameter(
            name=("--tokenizer"),
            group=_GROUP_NAME,
        ),
    ] = TokenizerDefaults.NAME

    revision: Annotated[
        str,
        Field(
            description=(
                "The specific model version to use.\n"
                "It can be a branch name, tag name, or commit ID."
            ),
        ),
        cyclopts.Parameter(
            name=("--tokenizer-revision"),
            group=_GROUP_NAME,
        ),
    ] = TokenizerDefaults.REVISION

    trust_remote_code: Annotated[
        bool,
        Field(
            description=(
                "Allows custom tokenizer to be downloaded and executed.\n"
                "This carries security risks and should only be used for repositories you trust.\n"
                "This is only necessary for custom tokenizers stored in HuggingFace Hub."
            ),
        ),
        cyclopts.Parameter(
            name=("--tokenizer-trust-remote-code"),
            group=_GROUP_NAME,
        ),
    ] = TokenizerDefaults.TRUST_REMOTE_CODE

aiperf.common.config.user_config

UserConfig

Bases: BaseConfig

A configuration class for defining top-level user settings.

Source code in aiperf/common/config/user_config.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
class UserConfig(BaseConfig):
    """
    A configuration class for defining top-level user settings.
    """

    model_names: Annotated[
        list[str],
        Field(
            ...,
            description="Model name(s) to be benchmarked. Can be a comma-separated list or a single model name.",
        ),
        BeforeValidator(parse_str_or_list),
        cyclopts.Parameter(
            name=(
                "--model-names",
                "--model",  # GenAI-Perf
                "-m",  # GenAI-Perf
            ),
            group="Endpoint",
        ),
    ]

    endpoint: Annotated[
        EndPointConfig,
        Field(
            description="Endpoint configuration",
        ),
    ] = EndPointConfig()

    input: Annotated[
        InputConfig,
        Field(
            description="Input configuration",
        ),
    ] = InputConfig()

    output: Annotated[
        OutputConfig,
        Field(
            description="Output configuration",
        ),
    ] = OutputConfig()

    tokenizer: Annotated[
        TokenizerConfig,
        Field(
            description="Tokenizer configuration",
        ),
    ] = TokenizerConfig()

    loadgen: Annotated[
        LoadGeneratorConfig,
        Field(
            description="Load Generator configuration",
        ),
    ] = LoadGeneratorConfig()

aiperf.common.config.worker_config

WorkersConfig

Bases: BaseConfig

Worker configuration.

Source code in aiperf/common/config/worker_config.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class WorkersConfig(BaseConfig):
    """Worker configuration."""

    _GROUP_NAME = "Workers"

    min: Annotated[
        int | None,
        Field(
            description="Minimum number of workers to maintain",
        ),
        cyclopts.Parameter(
            name=("--workers-min", "--min-workers"),
            group=_GROUP_NAME,
        ),
    ] = WorkersDefaults.MIN

    max: Annotated[
        int | None,
        Field(
            description="Maximum number of workers to create. If not specified, the number of"
            " workers will be determined by the smaller of (concurrency + 1) and (num CPUs - 1).",
        ),
        cyclopts.Parameter(
            name=("--workers-max", "--max-workers"),
            group=_GROUP_NAME,
        ),
    ] = WorkersDefaults.MAX

    health_check_interval_seconds: Annotated[
        float,
        Field(
            description="Interval in seconds for workers to publish their health status.",
        ),
        cyclopts.Parameter(
            name=("--workers-health-check-interval-seconds"),
            group=_GROUP_NAME,
        ),
    ] = WorkersDefaults.HEALTH_CHECK_INTERVAL_SECONDS

aiperf.common.config.zmq_config

BaseZMQCommunicationConfig

Bases: BaseModel, ABC

Configuration for ZMQ communication.

Source code in aiperf/common/config/zmq_config.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
class BaseZMQCommunicationConfig(BaseModel, ABC):
    """Configuration for ZMQ communication."""

    # Proxy config options to be overridden by subclasses
    event_bus_proxy_config: ClassVar[BaseZMQProxyConfig]
    dataset_manager_proxy_config: ClassVar[BaseZMQProxyConfig]
    raw_inference_proxy_config: ClassVar[BaseZMQProxyConfig]

    @property
    @abstractmethod
    def records_push_pull_address(self) -> str:
        """Get the inference push/pull address based on protocol configuration."""

    @property
    @abstractmethod
    def credit_drop_address(self) -> str:
        """Get the credit drop address based on protocol configuration."""

    @property
    @abstractmethod
    def credit_return_address(self) -> str:
        """Get the credit return address based on protocol configuration."""

    def get_address(self, address_type: CommunicationClientAddressType) -> str:
        """Get the actual address based on the address type."""
        address_map = {
            CommunicationClientAddressType.EVENT_BUS_PROXY_FRONTEND: self.event_bus_proxy_config.frontend_address,
            CommunicationClientAddressType.EVENT_BUS_PROXY_BACKEND: self.event_bus_proxy_config.backend_address,
            CommunicationClientAddressType.DATASET_MANAGER_PROXY_FRONTEND: self.dataset_manager_proxy_config.frontend_address,
            CommunicationClientAddressType.DATASET_MANAGER_PROXY_BACKEND: self.dataset_manager_proxy_config.backend_address,
            CommunicationClientAddressType.CREDIT_DROP: self.credit_drop_address,
            CommunicationClientAddressType.CREDIT_RETURN: self.credit_return_address,
            CommunicationClientAddressType.RECORDS: self.records_push_pull_address,
            CommunicationClientAddressType.RAW_INFERENCE_PROXY_FRONTEND: self.raw_inference_proxy_config.frontend_address,
            CommunicationClientAddressType.RAW_INFERENCE_PROXY_BACKEND: self.raw_inference_proxy_config.backend_address,
        }

        if address_type not in address_map:
            raise ValueError(f"Invalid address type: {address_type}")

        return address_map[address_type]

credit_drop_address abstractmethod property

Get the credit drop address based on protocol configuration.

credit_return_address abstractmethod property

Get the credit return address based on protocol configuration.

records_push_pull_address abstractmethod property

Get the inference push/pull address based on protocol configuration.

get_address(address_type)

Get the actual address based on the address type.

Source code in aiperf/common/config/zmq_config.py
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def get_address(self, address_type: CommunicationClientAddressType) -> str:
    """Get the actual address based on the address type."""
    address_map = {
        CommunicationClientAddressType.EVENT_BUS_PROXY_FRONTEND: self.event_bus_proxy_config.frontend_address,
        CommunicationClientAddressType.EVENT_BUS_PROXY_BACKEND: self.event_bus_proxy_config.backend_address,
        CommunicationClientAddressType.DATASET_MANAGER_PROXY_FRONTEND: self.dataset_manager_proxy_config.frontend_address,
        CommunicationClientAddressType.DATASET_MANAGER_PROXY_BACKEND: self.dataset_manager_proxy_config.backend_address,
        CommunicationClientAddressType.CREDIT_DROP: self.credit_drop_address,
        CommunicationClientAddressType.CREDIT_RETURN: self.credit_return_address,
        CommunicationClientAddressType.RECORDS: self.records_push_pull_address,
        CommunicationClientAddressType.RAW_INFERENCE_PROXY_FRONTEND: self.raw_inference_proxy_config.frontend_address,
        CommunicationClientAddressType.RAW_INFERENCE_PROXY_BACKEND: self.raw_inference_proxy_config.backend_address,
    }

    if address_type not in address_map:
        raise ValueError(f"Invalid address type: {address_type}")

    return address_map[address_type]

BaseZMQProxyConfig

Bases: BaseModel, ABC

Configuration Protocol for ZMQ Proxy.

Source code in aiperf/common/config/zmq_config.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class BaseZMQProxyConfig(BaseModel, ABC):
    """Configuration Protocol for ZMQ Proxy."""

    @property
    @abstractmethod
    def frontend_address(self) -> str:
        """Get the frontend address based on protocol configuration."""

    @property
    @abstractmethod
    def backend_address(self) -> str:
        """Get the backend address based on protocol configuration."""

    @property
    @abstractmethod
    def control_address(self) -> str | None:
        """Get the control address based on protocol configuration."""

    @property
    @abstractmethod
    def capture_address(self) -> str | None:
        """Get the capture address based on protocol configuration."""

backend_address abstractmethod property

Get the backend address based on protocol configuration.

capture_address abstractmethod property

Get the capture address based on protocol configuration.

control_address abstractmethod property

Get the control address based on protocol configuration.

frontend_address abstractmethod property

Get the frontend address based on protocol configuration.

ZMQIPCConfig

Bases: BaseZMQCommunicationConfig

Configuration for IPC transport.

Source code in aiperf/common/config/zmq_config.py
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
class ZMQIPCConfig(BaseZMQCommunicationConfig):
    """Configuration for IPC transport."""

    path: str = Field(default="/tmp/aiperf", description="Path for IPC sockets")
    dataset_manager_proxy_config: ZMQIPCProxyConfig = Field(  # type: ignore
        default=ZMQIPCProxyConfig(name="dataset_manager_proxy"),
        description="Configuration for the ZMQ Dealer Router Proxy. If provided, the proxy will be created and started.",
    )
    event_bus_proxy_config: ZMQIPCProxyConfig = Field(  # type: ignore
        default=ZMQIPCProxyConfig(name="event_bus_proxy"),
        description="Configuration for the ZMQ XPUB/XSUB Proxy. If provided, the proxy will be created and started.",
    )
    raw_inference_proxy_config: ZMQIPCProxyConfig = Field(  # type: ignore
        default=ZMQIPCProxyConfig(name="raw_inference_proxy"),
        description="Configuration for the ZMQ Push/Pull Proxy. If provided, the proxy will be created and started.",
    )

    @property
    def records_push_pull_address(self) -> str:
        """Get the records push/pull address based on protocol configuration."""
        return f"ipc://{self.path}/records_push_pull.ipc"

    @property
    def credit_drop_address(self) -> str:
        """Get the credit drop address based on protocol configuration."""
        return f"ipc://{self.path}/credit_drop.ipc"

    @property
    def credit_return_address(self) -> str:
        """Get the credit return address based on protocol configuration."""
        return f"ipc://{self.path}/credit_return.ipc"

credit_drop_address property

Get the credit drop address based on protocol configuration.

credit_return_address property

Get the credit return address based on protocol configuration.

records_push_pull_address property

Get the records push/pull address based on protocol configuration.

ZMQIPCProxyConfig

Bases: BaseZMQProxyConfig

Configuration for IPC proxy.

Source code in aiperf/common/config/zmq_config.py
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
class ZMQIPCProxyConfig(BaseZMQProxyConfig):
    """Configuration for IPC proxy."""

    path: str = Field(default="/tmp/aiperf", description="Path for IPC sockets")
    name: str = Field(default="proxy", description="Name for IPC sockets")
    enable_control: bool = Field(default=False, description="Enable control socket")
    enable_capture: bool = Field(default=False, description="Enable capture socket")

    @property
    def frontend_address(self) -> str:
        """Get the frontend address based on protocol configuration."""
        return f"ipc://{self.path}/{self.name}_frontend.ipc"

    @property
    def backend_address(self) -> str:
        """Get the backend address based on protocol configuration."""
        return f"ipc://{self.path}/{self.name}_backend.ipc"

    @property
    def control_address(self) -> str | None:
        """Get the control address based on protocol configuration."""
        return (
            f"ipc://{self.path}/{self.name}_control.ipc"
            if self.enable_control
            else None
        )

    @property
    def capture_address(self) -> str | None:
        """Get the capture address based on protocol configuration."""
        return (
            f"ipc://{self.path}/{self.name}_capture.ipc"
            if self.enable_capture
            else None
        )

backend_address property

Get the backend address based on protocol configuration.

capture_address property

Get the capture address based on protocol configuration.

control_address property

Get the control address based on protocol configuration.

frontend_address property

Get the frontend address based on protocol configuration.

ZMQTCPConfig

Bases: BaseZMQCommunicationConfig

Configuration for TCP transport.

Source code in aiperf/common/config/zmq_config.py
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
class ZMQTCPConfig(BaseZMQCommunicationConfig):
    """Configuration for TCP transport."""

    host: str = Field(
        default="0.0.0.0",
        description="Host address for TCP connections",
    )
    records_push_pull_port: int = Field(
        default=5557, description="Port for inference push/pull messages"
    )
    credit_drop_port: int = Field(
        default=5562, description="Port for credit drop operations"
    )
    credit_return_port: int = Field(
        default=5563, description="Port for credit return operations"
    )
    dataset_manager_proxy_config: ZMQTCPProxyConfig = Field(  # type: ignore
        default=ZMQTCPProxyConfig(
            frontend_port=5661,
            backend_port=5662,
        ),
        description="Configuration for the ZMQ Proxy. If provided, the proxy will be created and started.",
    )
    event_bus_proxy_config: ZMQTCPProxyConfig = Field(  # type: ignore
        default=ZMQTCPProxyConfig(
            frontend_port=5663,
            backend_port=5664,
        ),
        description="Configuration for the ZMQ Proxy. If provided, the proxy will be created and started.",
    )
    raw_inference_proxy_config: ZMQTCPProxyConfig = Field(  # type: ignore
        default=ZMQTCPProxyConfig(
            frontend_port=5665,
            backend_port=5666,
        ),
        description="Configuration for the ZMQ Proxy. If provided, the proxy will be created and started.",
    )

    @property
    def records_push_pull_address(self) -> str:
        """Get the records push/pull address based on protocol configuration."""
        return f"tcp://{self.host}:{self.records_push_pull_port}"

    @property
    def credit_drop_address(self) -> str:
        """Get the credit drop address based on protocol configuration."""
        return f"tcp://{self.host}:{self.credit_drop_port}"

    @property
    def credit_return_address(self) -> str:
        """Get the credit return address based on protocol configuration."""
        return f"tcp://{self.host}:{self.credit_return_port}"

credit_drop_address property

Get the credit drop address based on protocol configuration.

credit_return_address property

Get the credit return address based on protocol configuration.

records_push_pull_address property

Get the records push/pull address based on protocol configuration.

ZMQTCPProxyConfig

Bases: BaseZMQProxyConfig

Configuration for TCP proxy.

Source code in aiperf/common/config/zmq_config.py
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
class ZMQTCPProxyConfig(BaseZMQProxyConfig):
    """Configuration for TCP proxy."""

    host: str = Field(
        default="0.0.0.0",
        description="Host address for TCP connections",
    )
    frontend_port: int = Field(
        default=15555, description="Port for frontend address for proxy"
    )
    backend_port: int = Field(
        default=15556, description="Port for backend address for proxy"
    )
    control_port: int | None = Field(
        default=None, description="Port for control address for proxy"
    )
    capture_port: int | None = Field(
        default=None, description="Port for capture address for proxy"
    )

    @property
    def frontend_address(self) -> str:
        """Get the frontend address based on protocol configuration."""
        return f"tcp://{self.host}:{self.frontend_port}"

    @property
    def backend_address(self) -> str:
        """Get the backend address based on protocol configuration."""
        return f"tcp://{self.host}:{self.backend_port}"

    @property
    def control_address(self) -> str | None:
        """Get the control address based on protocol configuration."""
        return f"tcp://{self.host}:{self.control_port}" if self.control_port else None

    @property
    def capture_address(self) -> str | None:
        """Get the capture address based on protocol configuration."""
        return f"tcp://{self.host}:{self.capture_port}" if self.capture_port else None

backend_address property

Get the backend address based on protocol configuration.

capture_address property

Get the capture address based on protocol configuration.

control_address property

Get the control address based on protocol configuration.

frontend_address property

Get the frontend address based on protocol configuration.

aiperf.common.constants

DEFAULT_COMMS_REQUEST_TIMEOUT = 10.0 module-attribute

Default timeout for requests from req_clients to rep_clients in seconds.

TASK_CANCEL_TIMEOUT_LONG = 5.0 module-attribute

Maximum time to wait for complex tasks to complete when cancelling them (like parent tasks).

TASK_CANCEL_TIMEOUT_SHORT = 2.0 module-attribute

Maximum time to wait for simple tasks to complete when cancelling them.

aiperf.common.enums.base_enums

CaseInsensitiveStrEnum

Bases: str, Enum

CaseInsensitiveStrEnum is a custom enumeration class that extends str and Enum to provide case-insensitive lookup functionality for its members.

Source code in aiperf/common/enums/base_enums.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class CaseInsensitiveStrEnum(str, Enum):
    """
    CaseInsensitiveStrEnum is a custom enumeration class that extends `str` and `Enum` to provide case-insensitive
    lookup functionality for its members.
    """

    def __str__(self) -> str:
        return self.value

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}.{self.name}"

    def __eq__(self, other: Any) -> bool:
        if isinstance(other, str):
            return self.value.lower() == other.lower()
        return super().__eq__(other)

    def __hash__(self) -> int:
        return hash(self.value.lower())

    @classmethod
    def _missing_(cls, value):
        """
        Handles cases where a value is not directly found in the enumeration.

        This method is called when an attempt is made to access an enumeration
        member using a value that does not directly match any of the defined
        members. It provides custom logic to handle such cases.

        Returns:
            The matching enumeration member if a case-insensitive match is found
            for string values; otherwise, returns None.
        """
        if isinstance(value, str):
            for member in cls:
                if member.value.lower() == value.lower():
                    return member
        return None

aiperf.common.enums.benchmark_suite_enums

BenchmarkSuiteCompletionTrigger

Bases: CaseInsensitiveStrEnum

Determines how suite completion is detected, in order to know how to track the progress.

Source code in aiperf/common/enums/benchmark_suite_enums.py
 7
 8
 9
10
11
class BenchmarkSuiteCompletionTrigger(CaseInsensitiveStrEnum):
    """Determines how the suite completion is determined in order to know how to track the progress."""

    COMPLETED_PROFILES = "completed_profiles"
    """The suite will run until all profiles are completed."""

COMPLETED_PROFILES = 'completed_profiles' class-attribute instance-attribute

The suite will run until all profiles are completed.

BenchmarkSuiteType

Bases: CaseInsensitiveStrEnum

Determines the type of suite to know how to track the progress.

Source code in aiperf/common/enums/benchmark_suite_enums.py
19
20
21
22
23
class BenchmarkSuiteType(CaseInsensitiveStrEnum):
    """Determines the type of suite to know how to track the progress."""

    SINGLE_PROFILE = "single_profile"
    """A suite with a single profile run."""

SINGLE_PROFILE = 'single_profile' class-attribute instance-attribute

A suite with a single profile run.

aiperf.common.enums.command_enums

CommandResponseStatus

Bases: CaseInsensitiveStrEnum

Status of a command response.

Source code in aiperf/common/enums/command_enums.py
35
36
37
38
39
class CommandResponseStatus(CaseInsensitiveStrEnum):
    """Status of a command response."""

    SUCCESS = "success"
    FAILURE = "failure"

CommandType

Bases: CaseInsensitiveStrEnum

List of commands that the SystemController can send to component services.

Source code in aiperf/common/enums/command_enums.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class CommandType(CaseInsensitiveStrEnum):
    """List of commands that the SystemController can send to component services."""

    SHUTDOWN = "shutdown"
    """A command sent to shutdown a service. This will stop the service gracefully
    no matter what state it is in."""

    PROCESS_RECORDS = "process_records"
    """A command sent to process records. This will process the records and return
    the services to their pre-record processing state."""

    PROFILE_CONFIGURE = "profile_configure"
    """A command sent to configure a service in preparation for a profile run. This will
    override the current configuration."""

    PROFILE_START = "profile_start"
    """A command sent to indicate that a service should begin profiling using the
    current configuration."""

    PROFILE_STOP = "profile_stop"
    """A command sent to indicate that a service should stop doing profile related
    work, as the profile run is complete."""

    PROFILE_CANCEL = "profile_cancel"
    """A command sent to cancel a profile run. This will stop the current profile run and
    process the partial results."""

PROCESS_RECORDS = 'process_records' class-attribute instance-attribute

A command sent to process records. This will process the records and return the services to their pre-record processing state.

PROFILE_CANCEL = 'profile_cancel' class-attribute instance-attribute

A command sent to cancel a profile run. This will stop the current profile run and process the partial results.

PROFILE_CONFIGURE = 'profile_configure' class-attribute instance-attribute

A command sent to configure a service in preparation for a profile run. This will override the current configuration.

PROFILE_START = 'profile_start' class-attribute instance-attribute

A command sent to indicate that a service should begin profiling using the current configuration.

PROFILE_STOP = 'profile_stop' class-attribute instance-attribute

A command sent to indicate that a service should stop doing profile related work, as the profile run is complete.

SHUTDOWN = 'shutdown' class-attribute instance-attribute

A command sent to shutdown a service. This will stop the service gracefully no matter what state it is in.

aiperf.common.enums.communication_enums

CommunicationBackend

Bases: CaseInsensitiveStrEnum

Supported communication backends.

Source code in aiperf/common/enums/communication_enums.py
 7
 8
 9
10
11
12
13
14
class CommunicationBackend(CaseInsensitiveStrEnum):
    """Supported communication backends."""

    ZMQ_TCP = "zmq_tcp"
    """ZeroMQ backend using TCP sockets."""

    ZMQ_IPC = "zmq_ipc"
    """ZeroMQ backend using IPC sockets."""

ZMQ_IPC = 'zmq_ipc' class-attribute instance-attribute

ZeroMQ backend using IPC sockets.

ZMQ_TCP = 'zmq_tcp' class-attribute instance-attribute

ZeroMQ backend using TCP sockets.

CommunicationClientAddressType

Bases: CaseInsensitiveStrEnum

Enum for specifying the address type for communication clients. This is used to look up the address in the communication config.

Source code in aiperf/common/enums/communication_enums.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
class CommunicationClientAddressType(CaseInsensitiveStrEnum):
    """Enum for specifying the address type for communication clients.
    This is used to lookup the address in the communication config."""

    EVENT_BUS_PROXY_FRONTEND = "event_bus_proxy_frontend"
    """Frontend address for services to publish messages to."""

    EVENT_BUS_PROXY_BACKEND = "event_bus_proxy_backend"
    """Backend address for services to subscribe to messages."""

    CREDIT_DROP = "credit_drop"
    """Address to send CreditDrop messages from the TimingManager to the Worker."""

    CREDIT_RETURN = "credit_return"
    """Address to send CreditReturn messages from the Worker to the TimingManager."""

    RECORDS = "records"
    """Address to send parsed records from InferenceParser to RecordManager."""

    DATASET_MANAGER_PROXY_FRONTEND = "dataset_manager_proxy_frontend"
    """Frontend address for sending requests to the DatasetManager."""

    DATASET_MANAGER_PROXY_BACKEND = "dataset_manager_proxy_backend"
    """Backend address for the DatasetManager to receive requests from clients."""

    RAW_INFERENCE_PROXY_FRONTEND = "raw_inference_proxy_frontend"
    """Frontend address for sending raw inference messages to the InferenceParser from Workers."""

    RAW_INFERENCE_PROXY_BACKEND = "raw_inference_proxy_backend"
    """Backend address for the InferenceParser to receive raw inference messages from Workers."""

CREDIT_DROP = 'credit_drop' class-attribute instance-attribute

Address to send CreditDrop messages from the TimingManager to the Worker.

CREDIT_RETURN = 'credit_return' class-attribute instance-attribute

Address to send CreditReturn messages from the Worker to the TimingManager.

DATASET_MANAGER_PROXY_BACKEND = 'dataset_manager_proxy_backend' class-attribute instance-attribute

Backend address for the DatasetManager to receive requests from clients.

DATASET_MANAGER_PROXY_FRONTEND = 'dataset_manager_proxy_frontend' class-attribute instance-attribute

Frontend address for sending requests to the DatasetManager.

EVENT_BUS_PROXY_BACKEND = 'event_bus_proxy_backend' class-attribute instance-attribute

Backend address for services to subscribe to messages.

EVENT_BUS_PROXY_FRONTEND = 'event_bus_proxy_frontend' class-attribute instance-attribute

Frontend address for services to publish messages to.

RAW_INFERENCE_PROXY_BACKEND = 'raw_inference_proxy_backend' class-attribute instance-attribute

Backend address for the InferenceParser to receive raw inference messages from Workers.

RAW_INFERENCE_PROXY_FRONTEND = 'raw_inference_proxy_frontend' class-attribute instance-attribute

Frontend address for sending raw inference messages to the InferenceParser from Workers.

RECORDS = 'records' class-attribute instance-attribute

Address to send parsed records from InferenceParser to RecordManager.

CommunicationClientType

Bases: CaseInsensitiveStrEnum

Enum for specifying the communication client type for communication clients.

Source code in aiperf/common/enums/communication_enums.py
17
18
19
20
21
22
23
24
25
class CommunicationClientType(CaseInsensitiveStrEnum):
    """Enum for specifying the communication client type for communication clients."""

    PUB = "pub"
    SUB = "sub"
    PUSH = "push"
    PULL = "pull"
    REQUEST = "request"
    REPLY = "reply"

ZMQProxyType

Bases: CaseInsensitiveStrEnum

Types of ZMQ proxies.

Source code in aiperf/common/enums/communication_enums.py
60
61
62
63
64
65
class ZMQProxyType(CaseInsensitiveStrEnum):
    """Types of ZMQ proxies."""

    DEALER_ROUTER = "dealer_router"
    XPUB_XSUB = "xpub_xsub"
    PUSH_PULL = "push_pull"

aiperf.common.enums.data_exporter_enums

aiperf.common.enums.dataset_enums

AudioFormat

Bases: CaseInsensitiveStrEnum

Types of audio formats supported by AIPerf.

Source code in aiperf/common/enums/dataset_enums.py
34
35
36
37
38
class AudioFormat(CaseInsensitiveStrEnum):
    """Types of audio formats supported by AIPerf."""

    WAV = "wav"
    MP3 = "mp3"

ComposerType

Bases: CaseInsensitiveStrEnum

The type of composer to use for the dataset.

Source code in aiperf/common/enums/dataset_enums.py
 7
 8
 9
10
11
12
13
14
class ComposerType(CaseInsensitiveStrEnum):
    """
    The type of composer to use for the dataset.
    """

    SYNTHETIC = "synthetic"
    CUSTOM = "custom"
    PUBLIC_DATASET = "public_dataset"

CustomDatasetType

Bases: CaseInsensitiveStrEnum

Defines the type of JSONL custom dataset from the user.

Source code in aiperf/common/enums/dataset_enums.py
17
18
19
20
21
22
23
class CustomDatasetType(CaseInsensitiveStrEnum):
    """Defines the type of JSONL custom dataset from the user."""

    SINGLE_TURN = "single_turn"
    MULTI_TURN = "multi_turn"
    RANDOM_POOL = "random_pool"
    MOONCAKE_TRACE = "mooncake_trace"

ImageFormat

Bases: CaseInsensitiveStrEnum

Types of image formats supported by AIPerf.

Source code in aiperf/common/enums/dataset_enums.py
26
27
28
29
30
31
class ImageFormat(CaseInsensitiveStrEnum):
    """Types of image formats supported by AIPerf."""

    PNG = "png"
    JPEG = "jpeg"
    RANDOM = "random"

PromptSource

Bases: CaseInsensitiveStrEnum

Source of prompts for the model.

Source code in aiperf/common/enums/dataset_enums.py
41
42
43
44
45
46
class PromptSource(CaseInsensitiveStrEnum):
    """Source of prompts for the model."""

    SYNTHETIC = "synthetic"
    FILE = "file"
    PAYLOAD = "payload"

aiperf.common.enums.endpoints_enums

EndpointType

Bases: CaseInsensitiveStrEnum

Endpoint types.

These determine the format of request payload to send to the model.

Similar to endpoint_type_map and OutputFormat from genai-perf.

Source code in aiperf/common/enums/endpoints_enums.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
class EndpointType(CaseInsensitiveStrEnum):
    """Endpoint types.

    These determine the format of request payload to send to the model.

    Similar to `endpoint_type_map` and `OutputFormat` from `genai-perf`.
    """

    OPENAI_CHAT_COMPLETIONS = "chat"
    OPENAI_COMPLETIONS = "completions"
    # OPENAI_EMBEDDINGS = "embeddings"
    # OPENAI_MULTIMODAL = "multimodal"
    OPENAI_RESPONSES = "responses"

    # TODO: implement other endpoints
    # HUGGINGFACE_GENERATE = "generate"

    # DYNAMIC_GRPC = "dynamic_grpc"
    # NVCLIP = "nvclip"
    # TEMPLATE = "template"

    # RANKINGS = "rankings"
    # IMAGE_RETRIEVAL = "image_retrieval"

    # TENSORRTLLM = "tensorrtllm"
    # TENSORRTLLM_ENGINE = "tensorrtllm_engine"

    # TRITON_GENERATE = "triton_generate"

    # DYNAMO_ENGINE = "dynamo_engine"

    def endpoint_path(self) -> str | None:
        """Get the endpoint path for the endpoint type."""
        endpoint_path_map = {
            # OpenAI endpoints
            EndpointType.OPENAI_CHAT_COMPLETIONS: "/v1/chat/completions",
            # EndpointType.OPENAI_MULTIMODAL: "/v1/chat/completions",
            EndpointType.OPENAI_COMPLETIONS: "/v1/completions",
            # EndpointType.OPENAI_EMBEDDINGS: "/v1/embeddings",
            EndpointType.OPENAI_RESPONSES: "/v1/responses",
            # TODO: implement other endpoints
            # Other
            # EndpointType.NVCLIP: "/v1/embeddings",
            # EndpointType.HUGGINGFACE_GENERATE: "/",  # HuggingFace TGI only exposes root endpoint
            # EndpointType.RANKINGS: "/v1/ranking",  # TODO: Not implemented yet
            # EndpointType.IMAGE_RETRIEVAL: "/v1/infer",  # TODO: Not implemented yet
            # EndpointType.TRITON_GENERATE: "/v2/models/{MODEL_NAME}/generate",  # TODO: Not implemented yet
            # # These endpoints do not have a specific path
            # EndpointType.DYNAMIC_GRPC: None,  # TODO: Not implemented yet
            # EndpointType.TEMPLATE: None,  # TODO: Not implemented yet
            # EndpointType.TENSORRTLLM: None,  # TODO: Not implemented yet
            # EndpointType.TENSORRTLLM_ENGINE: None,  # TODO: Not implemented yet
            # EndpointType.DYNAMO_ENGINE: None,  # TODO: Not implemented yet
        }

        if self not in endpoint_path_map:
            raise NotImplementedError(f"Endpoint not implemented for {self}")

        return endpoint_path_map[self]

    def response_payload_type(self) -> "ResponsePayloadType":
        """Get the response payload type for the request payload type."""
        return ResponsePayloadType.from_endpoint_type(self)

endpoint_path()

Get the endpoint path for the endpoint type.

Source code in aiperf/common/enums/endpoints_enums.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
def endpoint_path(self) -> str | None:
    """Get the endpoint path for the endpoint type.

    Returns the URL path that requests for this endpoint type are sent to.

    Raises:
        NotImplementedError: If no path is defined for this endpoint type.
    """
    endpoint_path_map = {
        # OpenAI endpoints
        EndpointType.OPENAI_CHAT_COMPLETIONS: "/v1/chat/completions",
        # EndpointType.OPENAI_MULTIMODAL: "/v1/chat/completions",
        EndpointType.OPENAI_COMPLETIONS: "/v1/completions",
        # EndpointType.OPENAI_EMBEDDINGS: "/v1/embeddings",
        EndpointType.OPENAI_RESPONSES: "/v1/responses",
        # TODO: implement other endpoints
        # Other
        # EndpointType.NVCLIP: "/v1/embeddings",
        # EndpointType.HUGGINGFACE_GENERATE: "/",  # HuggingFace TGI only exposes root endpoint
        # EndpointType.RANKINGS: "/v1/ranking",  # TODO: Not implemented yet
        # EndpointType.IMAGE_RETRIEVAL: "/v1/infer",  # TODO: Not implemented yet
        # EndpointType.TRITON_GENERATE: "/v2/models/{MODEL_NAME}/generate",  # TODO: Not implemented yet
        # # These endpoints do not have a specific path
        # EndpointType.DYNAMIC_GRPC: None,  # TODO: Not implemented yet
        # EndpointType.TEMPLATE: None,  # TODO: Not implemented yet
        # EndpointType.TENSORRTLLM: None,  # TODO: Not implemented yet
        # EndpointType.TENSORRTLLM_ENGINE: None,  # TODO: Not implemented yet
        # EndpointType.DYNAMO_ENGINE: None,  # TODO: Not implemented yet
    }

    # Endpoint types not present in the map (the commented-out entries above)
    # are not yet supported.
    if self not in endpoint_path_map:
        raise NotImplementedError(f"Endpoint not implemented for {self}")

    return endpoint_path_map[self]

response_payload_type()

Get the response payload type for the request payload type.

Source code in aiperf/common/enums/endpoints_enums.py
67
68
69
def response_payload_type(self) -> "ResponsePayloadType":
    """Get the response payload type for the request payload type.

    Delegates to ``ResponsePayloadType.from_endpoint_type``, which raises
    ``NotImplementedError`` for endpoint types without a payload mapping.
    """
    return ResponsePayloadType.from_endpoint_type(self)

ResponsePayloadType

Bases: CaseInsensitiveStrEnum

Response payload types.

These determine the format of the response payload that the model will return.

Equivalent to output_format from genai-perf.

Source code in aiperf/common/enums/endpoints_enums.py
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
class ResponsePayloadType(CaseInsensitiveStrEnum):
    """Response payload types.

    These determine the format of the response payload that the model will return.

    Equivalent to `output_format` from `genai-perf`.
    """

    OPENAI_CHAT_COMPLETIONS = "openai_chat_completions"
    OPENAI_COMPLETIONS = "openai_completions"
    OPENAI_RESPONSES = "openai_responses"

    # TODO: implement the remaining payload types (embeddings, multimodal,
    # huggingface_generate, rankings, image_retrieval, ...)

    @classmethod
    def from_endpoint_type(cls, endpoint_type: EndpointType) -> "ResponsePayloadType":
        """Get the response payload type for the endpoint type.

        Raises:
            NotImplementedError: If the endpoint type has no payload mapping.
        """
        mapping = {
            EndpointType.OPENAI_CHAT_COMPLETIONS: cls.OPENAI_CHAT_COMPLETIONS,
            EndpointType.OPENAI_COMPLETIONS: cls.OPENAI_COMPLETIONS,
            EndpointType.OPENAI_RESPONSES: cls.OPENAI_RESPONSES,
        }

        payload_type = mapping.get(endpoint_type)
        if payload_type is None:
            raise NotImplementedError(
                f"Payload type not implemented for {endpoint_type}"
            )
        return payload_type

from_endpoint_type(endpoint_type) classmethod

Get the response payload type for the endpoint type.

Source code in aiperf/common/enums/endpoints_enums.py
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
@classmethod
def from_endpoint_type(cls, endpoint_type: EndpointType) -> "ResponsePayloadType":
    """Get the response payload type for the endpoint type.

    Raises:
        NotImplementedError: If the endpoint type has no payload mapping.
    """
    endpoint_to_payload_map = {
        EndpointType.OPENAI_CHAT_COMPLETIONS: ResponsePayloadType.OPENAI_CHAT_COMPLETIONS,
        # EndpointType.OPENAI_MULTIMODAL: ResponsePayloadType.OPENAI_CHAT_COMPLETIONS,
        EndpointType.OPENAI_COMPLETIONS: ResponsePayloadType.OPENAI_COMPLETIONS,
        # EndpointType.OPENAI_EMBEDDINGS: ResponsePayloadType.OPENAI_EMBEDDINGS,
        EndpointType.OPENAI_RESPONSES: ResponsePayloadType.OPENAI_RESPONSES,
        # TODO: implement other endpoints
        # EndpointType.HUGGINGFACE_GENERATE: ResponsePayloadType.HUGGINGFACE_GENERATE,
        # EndpointType.RANKINGS: ResponsePayloadType.RANKINGS,
        # EndpointType.IMAGE_RETRIEVAL: ResponsePayloadType.IMAGE_RETRIEVAL,
    }

    # Endpoint types not present in the map are not yet supported.
    if endpoint_type not in endpoint_to_payload_map:
        raise NotImplementedError(
            f"Payload type not implemented for {endpoint_type}"
        )

    return endpoint_to_payload_map[endpoint_type]

aiperf.common.enums.logging_enums

AIPerfLogLevel

Bases: CaseInsensitiveStrEnum

Log levels for AIPerfLogger.

Source code in aiperf/common/enums/logging_enums.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class AIPerfLogLevel(CaseInsensitiveStrEnum):
    """Log levels for AIPerfLogger."""

    TRACE = "TRACE"
    DEBUG = "DEBUG"
    INFO = "INFO"
    NOTICE = "NOTICE"
    WARNING = "WARNING"
    SUCCESS = "SUCCESS"
    ERROR = "ERROR"
    CRITICAL = "CRITICAL"

    @property
    def level(self) -> int:
        """Get the integer level equivalent.

        Looks the member up in the module-level ``_LEVEL_MAP`` table.
        """
        return _LEVEL_MAP[self]

level property

Get the integer level equivalent.

aiperf.common.enums.measurement_enums

aiperf.common.enums.message_enums

MessageType

Bases: CaseInsensitiveStrEnum

The various types of messages that can be sent between services.

The message type is used to determine what Pydantic model the message maps to, based on the message_type field in the message model. For detailed explanations of each message type, go to its definition in the `aiperf.common.messages` module.

Source code in aiperf/common/enums/message_enums.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
class MessageType(CaseInsensitiveStrEnum):
    """The various types of messages that can be sent between services.

    The message type is used to determine what Pydantic model the message maps to,
    based on the message_type field in the message model. For detailed explanations
    of each message type, go to its definition in :mod:`aiperf.common.messages`.
    """

    # Values are the wire-format strings carried in the `message_type` field
    # of each message model.
    COMMAND = "command"
    COMMAND_RESPONSE = "command_response"
    CONVERSATION_REQUEST = "conversation_request"
    CONVERSATION_RESPONSE = "conversation_response"
    CONVERSATION_TURN_REQUEST = "conversation_turn_request"
    CONVERSATION_TURN_RESPONSE = "conversation_turn_response"
    CREDITS_COMPLETE = "credits_complete"
    CREDIT_DROP = "credit_drop"
    CREDIT_PHASE_COMPLETE = "credit_phase_complete"
    CREDIT_PHASE_PROGRESS = "credit_phase_progress"
    CREDIT_PHASE_SENDING_COMPLETE = "credit_phase_sending_complete"
    CREDIT_PHASE_START = "credit_phase_start"
    CREDIT_RETURN = "credit_return"
    DATASET_CONFIGURED_NOTIFICATION = "dataset_configured_notification"
    DATASET_TIMING_REQUEST = "dataset_timing_request"
    DATASET_TIMING_RESPONSE = "dataset_timing_response"
    ERROR = "error"
    HEARTBEAT = "heartbeat"
    INFERENCE_RESULTS = "inference_results"
    NOTIFICATION = "notification"
    PARSED_INFERENCE_RESULTS = "parsed_inference_results"
    PROCESSING_STATS = "processing_stats"
    PROCESS_RECORDS_REQUEST = "process_records_request"
    PROCESS_RECORDS_RESPONSE = "process_records_response"
    PROFILE_ERROR = "profile_error"
    PROFILE_PROGRESS = "profile_progress"
    PROFILE_RESULTS = "profile_results"
    REGISTRATION = "registration"
    SERVICE_ERROR = "service_error"
    STATUS = "status"
    SWEEP_BEGIN = "sweep_begin"
    SWEEP_CONFIGURE = "sweep_configure"
    SWEEP_END = "sweep_end"
    SWEEP_ERROR = "sweep_error"
    SWEEP_PROGRESS = "sweep_progress"
    SWEEP_RESULTS = "sweep_results"
    UNKNOWN = "unknown"
    WORKER_HEALTH = "worker_health"

NotificationType

Bases: CaseInsensitiveStrEnum

Types of notifications that can be sent to other services.

Source code in aiperf/common/enums/message_enums.py
55
56
57
58
59
class NotificationType(CaseInsensitiveStrEnum):
    """Types of notifications that can be sent to other services."""

    # Currently the only notification type; more may be added over time.
    DATASET_CONFIGURED = "dataset_configured"
    """A notification sent to notify other services that the dataset has been configured."""

DATASET_CONFIGURED = 'dataset_configured' class-attribute instance-attribute

A notification sent to notify other services that the dataset has been configured.

aiperf.common.enums.metric_enums

MetricTimeType

Bases: CaseInsensitiveStrEnum

Defines the time types for metrics.

Source code in aiperf/common/enums/metric_enums.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class MetricTimeType(CaseInsensitiveStrEnum):
    """Defines the time types for metrics."""

    NANOSECONDS = "nanoseconds"
    MILLISECONDS = "milliseconds"
    SECONDS = "seconds"

    def short_name(self) -> str:
        """Return the abbreviated unit symbol for this time type (e.g. "ns")."""
        abbreviations = {
            MetricTimeType.NANOSECONDS: "ns",
            MetricTimeType.MILLISECONDS: "ms",
            MetricTimeType.SECONDS: "s",
        }
        return abbreviations[self]

short_name()

Get the short name for the time type.

Source code in aiperf/common/enums/metric_enums.py
16
17
18
19
20
21
22
23
def short_name(self) -> str:
    """Get the short name for the time type.

    Returns the abbreviated unit symbol, e.g. "ns" for NANOSECONDS.
    """
    _short_name_map = {
        MetricTimeType.NANOSECONDS: "ns",
        MetricTimeType.MILLISECONDS: "ms",
        MetricTimeType.SECONDS: "s",
    }
    return _short_name_map[self]

aiperf.common.enums.model_enums

Modality

Bases: CaseInsensitiveStrEnum

Modality of the model. Can be used to determine the type of data to send to the model in conjunction with the ModelSelectionStrategy.MODALITY_AWARE.

Source code in aiperf/common/enums/model_enums.py
 7
 8
 9
10
11
12
13
14
15
16
class Modality(CaseInsensitiveStrEnum):
    """Modality of the model. Can be used to determine the type of data to send to the model in
    conjunction with the ModelSelectionStrategy.MODALITY_AWARE."""

    TEXT = "text"
    IMAGE = "image"
    AUDIO = "audio"
    VIDEO = "video"
    # MULTIMODAL covers models that accept more than one of the above.
    MULTIMODAL = "multimodal"
    CUSTOM = "custom"

ModelSelectionStrategy

Bases: CaseInsensitiveStrEnum

Strategy for selecting the model to use for the request.

Source code in aiperf/common/enums/model_enums.py
19
20
21
22
23
24
class ModelSelectionStrategy(CaseInsensitiveStrEnum):
    """Strategy for selecting the model to use for the request."""

    ROUND_ROBIN = "round_robin"
    RANDOM = "random"
    # Selects a model based on its Modality; see the Modality enum.
    MODALITY_AWARE = "modality_aware"

aiperf.common.enums.post_processor_enums

StreamingPostProcessorType

Bases: CaseInsensitiveStrEnum

Type of response streamer.

Source code in aiperf/common/enums/post_processor_enums.py
11
12
13
14
15
16
17
18
19
20
21
class StreamingPostProcessorType(CaseInsensitiveStrEnum):
    """Type of response streamer.

    Identifies which streaming post-processor implementation to use for records.
    """

    PROCESSING_STATS = "processing_stats"
    """Streamer that provides the processing stats of the records."""

    BASIC_METRICS = "basic_metrics"
    """Streamer that handles the basic metrics of the records."""

    JSONL = "jsonl"
    """Streams all parsed records to a JSONL file."""

BASIC_METRICS = 'basic_metrics' class-attribute instance-attribute

Streamer that handles the basic metrics of the records.

JSONL = 'jsonl' class-attribute instance-attribute

Streams all parsed records to a JSONL file.

PROCESSING_STATS = 'processing_stats' class-attribute instance-attribute

Streamer that provides the processing stats of the records.

aiperf.common.enums.service_enums

ServiceRegistrationStatus

Bases: CaseInsensitiveStrEnum

Defines the various states a service can be in during registration with the SystemController.

Source code in aiperf/common/enums/service_enums.py
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
class ServiceRegistrationStatus(CaseInsensitiveStrEnum):
    """Defines the various states a service can be in during registration with
    the SystemController.

    Typical progression: UNREGISTERED -> WAITING -> REGISTERED (or TIMEOUT/ERROR).
    """

    UNREGISTERED = "unregistered"
    """The service is not registered with the SystemController. This is the
    initial state."""

    WAITING = "waiting"
    """The service is waiting for the SystemController to register it.
    This is a temporary state that should be followed by REGISTERED, TIMEOUT, or ERROR."""

    REGISTERED = "registered"
    """The service is registered with the SystemController."""

    TIMEOUT = "timeout"
    """The service registration timed out."""

    ERROR = "error"
    """The service registration failed."""

ERROR = 'error' class-attribute instance-attribute

The service registration failed.

REGISTERED = 'registered' class-attribute instance-attribute

The service is registered with the SystemController.

TIMEOUT = 'timeout' class-attribute instance-attribute

The service registration timed out.

UNREGISTERED = 'unregistered' class-attribute instance-attribute

The service is not registered with the SystemController. This is the initial state.

WAITING = 'waiting' class-attribute instance-attribute

The service is waiting for the SystemController to register it. This is a temporary state that should be followed by REGISTERED, TIMEOUT, or ERROR.

ServiceRunType

Bases: CaseInsensitiveStrEnum

The different ways the SystemController should run the component services.

Source code in aiperf/common/enums/service_enums.py
 7
 8
 9
10
11
12
13
14
15
16
class ServiceRunType(CaseInsensitiveStrEnum):
    """The different ways the SystemController should run the component services."""

    MULTIPROCESSING = "process"
    """Run each service as a separate process.
    This is the default way for single-node deployments."""

    KUBERNETES = "k8s"
    """Run each service as a separate Kubernetes pod.
    This is the default way for multi-node deployments."""

KUBERNETES = 'k8s' class-attribute instance-attribute

Run each service as a separate Kubernetes pod. This is the default way for multi-node deployments.

MULTIPROCESSING = 'process' class-attribute instance-attribute

Run each service as a separate process. This is the default way for single-node deployments.

ServiceState

Bases: CaseInsensitiveStrEnum

States a service can be in throughout its lifecycle.

Source code in aiperf/common/enums/service_enums.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
class ServiceState(CaseInsensitiveStrEnum):
    """States a service can be in throughout its lifecycle.

    Typical progression: INITIALIZING -> PENDING -> CONFIGURING -> READY ->
    STARTING -> RUNNING -> STOPPING -> STOPPED -> SHUTDOWN, with ERROR and
    UNKNOWN as out-of-band states.
    """

    UNKNOWN = "unknown"
    """The service is in an unknown state."""

    INITIALIZING = "initializing"
    """The service is currently initializing. This is a temporary state that should be
    followed by PENDING."""

    PENDING = "pending"
    """The service is pending configuration."""

    CONFIGURING = "configuring"
    """The service is currently configuring. This is a temporary state that should be
    followed by READY."""

    READY = "ready"
    """The service has been configured and is ready to be started."""

    STARTING = "starting"
    """The service is starting. This is a temporary state that should be followed
    by RUNNING."""

    RUNNING = "running"
    """The service is running."""

    STOPPING = "stopping"
    """The service is stopping. This is a temporary state that should be followed
    by STOPPED."""

    STOPPED = "stopped"
    """The service has been stopped."""

    SHUTDOWN = "shutdown"
    """The service has been shutdown."""

    ERROR = "error"
    """The service is currently in an error state."""

CONFIGURING = 'configuring' class-attribute instance-attribute

The service is currently configuring. This is a temporary state that should be followed by READY.

ERROR = 'error' class-attribute instance-attribute

The service is currently in an error state.

INITIALIZING = 'initializing' class-attribute instance-attribute

The service is currently initializing. This is a temporary state that should be followed by PENDING.

PENDING = 'pending' class-attribute instance-attribute

The service is pending configuration.

READY = 'ready' class-attribute instance-attribute

The service has been configured and is ready to be started.

RUNNING = 'running' class-attribute instance-attribute

The service is running.

SHUTDOWN = 'shutdown' class-attribute instance-attribute

The service has been shutdown.

STARTING = 'starting' class-attribute instance-attribute

The service is starting. This is a temporary state that should be followed by RUNNING.

STOPPED = 'stopped' class-attribute instance-attribute

The service has been stopped.

STOPPING = 'stopping' class-attribute instance-attribute

The service is stopping. This is a temporary state that should be followed by STOPPED.

UNKNOWN = 'unknown' class-attribute instance-attribute

The service is in an unknown state.

ServiceType

Bases: CaseInsensitiveStrEnum

Types of services in the AIPerf system.

This is used to identify the service type when registering with the SystemController. It can also be used for tracking purposes if multiple instances of the same service type are running.

Source code in aiperf/common/enums/service_enums.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
class ServiceType(CaseInsensitiveStrEnum):
    """Types of services in the AIPerf system.

    This is used to identify the service type when registering with the
    SystemController. It can also be used for tracking purposes if multiple
    instances of the same service type are running.
    """

    SYSTEM_CONTROLLER = "system_controller"
    DATASET_MANAGER = "dataset_manager"
    TIMING_MANAGER = "timing_manager"
    RECORDS_MANAGER = "records_manager"
    INFERENCE_RESULT_PARSER = "inference_result_parser"
    WORKER_MANAGER = "worker_manager"
    WORKER = "worker"

    # For testing purposes only
    TEST = "test_service"

aiperf.common.enums.sse_enums

SSEEventType

Bases: CaseInsensitiveStrEnum

Event types in an SSE message. Many of these are custom and not defined by the SSE spec.

Source code in aiperf/common/enums/sse_enums.py
17
18
19
20
21
class SSEEventType(CaseInsensitiveStrEnum):
    """Event types in an SSE message. Many of these are custom and not defined by the SSE spec."""

    ERROR = "error"
    # Custom event type carrying LLM metric data.
    LLM_METRICS = "llm_metrics"

SSEFieldType

Bases: CaseInsensitiveStrEnum

Field types in an SSE message.

Source code in aiperf/common/enums/sse_enums.py
 7
 8
 9
10
11
12
13
14
class SSEFieldType(CaseInsensitiveStrEnum):
    """Field types in an SSE message.

    These are the field names that can appear before the ':' separator in a
    Server-Sent Events stream line.
    """

    DATA = "data"
    EVENT = "event"
    ID = "id"
    RETRY = "retry"
    COMMENT = "comment"

aiperf.common.enums.system_enums

SystemState

Bases: CaseInsensitiveStrEnum

State of the system as a whole.

This is used to track the state of the system as a whole, and is used to determine what actions to take when a signal is received.

Source code in aiperf/common/enums/system_enums.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class SystemState(CaseInsensitiveStrEnum):
    """State of the system as a whole.

    This is used to track the state of the system as a whole, and is used to
    determine what actions to take when a signal is received.

    Typical progression: INITIALIZING -> CONFIGURING -> READY -> PROFILING ->
    PROCESSING -> STOPPING -> SHUTDOWN.
    """

    INITIALIZING = "initializing"
    """The system is initializing. This is the initial state."""

    CONFIGURING = "configuring"
    """The system is configuring services."""

    READY = "ready"
    """The system is ready to start profiling. This is a temporary state that should be
    followed by PROFILING."""

    PROFILING = "profiling"
    """The system is running a profiling run."""

    PROCESSING = "processing"
    """The system is processing results."""

    STOPPING = "stopping"
    """The system is stopping."""

    SHUTDOWN = "shutdown"
    """The system is shutting down. This is the final state."""

CONFIGURING = 'configuring' class-attribute instance-attribute

The system is configuring services.

INITIALIZING = 'initializing' class-attribute instance-attribute

The system is initializing. This is the initial state.

PROCESSING = 'processing' class-attribute instance-attribute

The system is processing results.

PROFILING = 'profiling' class-attribute instance-attribute

The system is running a profiling run.

READY = 'ready' class-attribute instance-attribute

The system is ready to start profiling. This is a temporary state that should be followed by PROFILING.

SHUTDOWN = 'shutdown' class-attribute instance-attribute

The system is shutting down. This is the final state.

STOPPING = 'stopping' class-attribute instance-attribute

The system is stopping.

aiperf.common.enums.timing_enums

CreditPhase

Bases: CaseInsensitiveStrEnum

The type of credit phase. This is used to identify which phase of the benchmark the credit is being used in, for tracking and reporting purposes.

Source code in aiperf/common/enums/timing_enums.py
30
31
32
33
34
35
36
37
38
39
40
class CreditPhase(CaseInsensitiveStrEnum):
    """The type of credit phase. This is used to identify which phase of the
    benchmark the credit is being used in, for tracking and reporting purposes."""

    WARMUP = "warmup"
    """The credit phase is the warmup phase. This is used to warm up the model
    before the benchmark starts."""

    PROFILING = "profiling"
    """The credit phase is the steady state phase. This is the primary phase of the
    benchmark, and what is used to calculate the final results."""

PROFILING = 'profiling' class-attribute instance-attribute

The credit phase is the steady state phase. This is the primary phase of the benchmark, and what is used to calculate the final results.

WARMUP = 'warmup' class-attribute instance-attribute

The credit phase is the warmup phase. This is used to warm up the model before the benchmark starts.

RequestRateMode

Bases: CaseInsensitiveStrEnum

The different ways the RequestRateStrategy should generate requests.

Source code in aiperf/common/enums/timing_enums.py
20
21
22
23
24
25
26
27
class RequestRateMode(CaseInsensitiveStrEnum):
    """The different ways the RequestRateStrategy should generate requests."""

    CONSTANT = "constant"
    """Generate requests at a constant rate."""

    POISSON = "poisson"
    """Generate requests using a poisson distribution."""

CONSTANT = 'constant' class-attribute instance-attribute

Generate requests at a constant rate.

POISSON = 'poisson' class-attribute instance-attribute

Generate requests using a poisson distribution.

TimingMode

Bases: CaseInsensitiveStrEnum

The different ways the TimingManager should generate requests.

Source code in aiperf/common/enums/timing_enums.py
 7
 8
 9
10
11
12
13
14
15
16
17
class TimingMode(CaseInsensitiveStrEnum):
    """The different ways the TimingManager should generate requests."""

    FIXED_SCHEDULE = "fixed_schedule"
    """A mode where the TimingManager will send requests according to a fixed schedule."""

    CONCURRENCY = "concurrency"
    """A mode where the TimingManager will maintain a continuous stream of concurrent requests."""

    REQUEST_RATE = "request_rate"
    """A mode where the TimingManager will send requests at either a constant request rate or based on a poisson distribution."""

CONCURRENCY = 'concurrency' class-attribute instance-attribute

A mode where the TimingManager will maintain a continuous stream of concurrent requests.

FIXED_SCHEDULE = 'fixed_schedule' class-attribute instance-attribute

A mode where the TimingManager will send requests according to a fixed schedule.

REQUEST_RATE = 'request_rate' class-attribute instance-attribute

A mode where the TimingManager will send requests at either a constant request rate or based on a poisson distribution.

aiperf.common.exceptions

AIPerfError

Bases: Exception

Base class for all exceptions raised by AIPerf.

Source code in aiperf/common/exceptions.py
 7
 8
 9
10
11
12
13
14
15
16
class AIPerfError(Exception):
    """Base class for all exceptions raised by AIPerf."""

    def raw_str(self) -> str:
        """Return the exception message without the class-name prefix."""
        return super().__str__()

    def __str__(self) -> str:
        """Return the message prefixed with the concrete exception class name."""
        return f"{type(self).__name__}: {super().__str__()}"

__str__()

Return the string representation of the exception with the class name.

Source code in aiperf/common/exceptions.py
14
15
16
def __str__(self) -> str:
    """Return the string representation of the exception with the class name.

    Format: "<ClassName>: <message>".
    """
    return f"{self.__class__.__name__}: {super().__str__()}"

raw_str()

Return the raw string representation of the exception.

Source code in aiperf/common/exceptions.py
10
11
12
def raw_str(self) -> str:
    """Return the raw string representation of the exception.

    Unlike ``__str__``, this does not prepend the exception class name.
    """
    return super().__str__()

AIPerfMultiError

Bases: AIPerfError

Exception raised when running multiple tasks and one or more fail.

Source code in aiperf/common/exceptions.py
19
20
21
22
23
24
25
26
27
class AIPerfMultiError(AIPerfError):
    """Exception raised when running multiple tasks and one or more fail."""

    def __init__(self, message: str, exceptions: list[Exception]) -> None:
        # Use raw_str() for AIPerf errors so the combined message does not
        # repeat each exception's class name (AIPerfError.__str__ prepends it).
        parts: list[str] = []
        for exc in exceptions:
            if isinstance(exc, AIPerfError):
                parts.append(exc.raw_str())
            else:
                parts.append(str(exc))
        super().__init__(f"{message}: {','.join(parts)}")
        self.exceptions = exceptions

CommunicationError

Bases: AIPerfError

Generic communication error.

Source code in aiperf/common/exceptions.py
70
71
class CommunicationError(AIPerfError):
    """Generic communication error.

    Carries no extra state beyond :class:`AIPerfError`.
    """

ConfigurationError

Bases: AIPerfError

Exception raised when something fails to configure, or there is a configuration error.

Source code in aiperf/common/exceptions.py
50
51
class ConfigurationError(AIPerfError):
    """Exception raised when something fails to configure, or there is a configuration error.

    Carries no extra state beyond :class:`AIPerfError`.
    """

DatasetError

Bases: AIPerfError

Generic dataset error.

Source code in aiperf/common/exceptions.py
74
75
class DatasetError(AIPerfError):
    """Generic dataset error.

    Carries no extra state beyond :class:`AIPerfError`.
    """

DatasetGeneratorError

Bases: AIPerfError

Generic dataset generator error.

Source code in aiperf/common/exceptions.py
78
79
class DatasetGeneratorError(AIPerfError):
    """Generic dataset generator error.

    Carries no extra state beyond :class:`AIPerfError`.
    """

FactoryCreationError

Bases: AIPerfError

Exception raised when a factory encounters an error while creating a class.

Source code in aiperf/common/exceptions.py
94
95
class FactoryCreationError(AIPerfError):
    """Exception raised when a factory encounters an error while creating a class.

    Carries no extra state beyond :class:`AIPerfError`.
    """

InferenceClientError

Bases: AIPerfError

Exception raised when an inference client encounters an error.

Source code in aiperf/common/exceptions.py
82
83
class InferenceClientError(AIPerfError):
    """Exception raised when an inference client encounters an error."""

InitializationError

Bases: AIPerfError

Exception raised when something fails to initialize.

Source code in aiperf/common/exceptions.py
46
47
class InitializationError(AIPerfError):
    """Exception raised when something fails to initialize.

    Carries no extra state beyond :class:`AIPerfError`.
    """

InvalidPayloadError

Bases: InferenceClientError

Exception raised when an inference client receives an invalid payload.

Source code in aiperf/common/exceptions.py
86
87
class InvalidPayloadError(InferenceClientError):
    """Exception raised when an inference client receives an invalid payload."""

InvalidStateError

Bases: AIPerfError

Exception raised when something is in an invalid state.

Source code in aiperf/common/exceptions.py
58
59
class InvalidStateError(AIPerfError):
    """Exception raised when something is in an invalid state.

    Carries no extra state beyond :class:`AIPerfError`.
    """

MetricTypeError

Bases: AIPerfError

Exception raised when a metric type encounters an error while creating a class.

Source code in aiperf/common/exceptions.py
98
99
class MetricTypeError(AIPerfError):
    """Exception raised when a metric type encounters an error while creating a class.

    Carries no extra state beyond :class:`AIPerfError`.
    """

NotFoundError

Bases: AIPerfError

Exception raised when something is not found or not available.

Source code in aiperf/common/exceptions.py
66
67
class NotFoundError(AIPerfError):
    """Exception raised when something is not found or not available.

    Carries no extra state beyond :class:`AIPerfError`.
    """

NotInitializedError

Bases: AIPerfError

Exception raised when something that should be initialized is not.

Source code in aiperf/common/exceptions.py
54
55
class NotInitializedError(AIPerfError):
    """Exception raised when something that should be initialized is not.

    Carries no extra state beyond :class:`AIPerfError`.
    """

ProxyError

Bases: AIPerfError

Exception raised when a proxy encounters an error.

Source code in aiperf/common/exceptions.py
106
107
class ProxyError(AIPerfError):
    """Exception raised when a proxy encounters an error.

    Carries no extra state beyond :class:`AIPerfError`.
    """

ServiceError

Bases: AIPerfError

Generic service error.

Source code in aiperf/common/exceptions.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class ServiceError(AIPerfError):
    """Generic service error.

    Records which service raised it via ``service_type`` and ``service_id``.
    """

    def __init__(
        self,
        message: str,
        service_type: ServiceType,
        service_id: str,
    ) -> None:
        detail = f"{message} for service of type {service_type} with id {service_id}"
        super().__init__(detail)
        # Kept for programmatic inspection by handlers.
        self.service_type = service_type
        self.service_id = service_id

ShutdownError

Bases: AIPerfError

Exception raised when a service encounters an error while shutting down.

Source code in aiperf/common/exceptions.py
102
103
class ShutdownError(AIPerfError):
    """Exception raised when a service encounters an error while shutting down.

    Carries no extra state beyond :class:`AIPerfError`.
    """

UnsupportedHookError

Bases: AIPerfError

Exception raised when a hook is defined on a class that does not support it.

Source code in aiperf/common/exceptions.py
90
91
class UnsupportedHookError(AIPerfError):
    """Exception raised when a hook is defined on a class that does not support it.

    Carries no extra state beyond :class:`AIPerfError`.
    """

ValidationError

Bases: AIPerfError

Exception raised when something fails validation.

Source code in aiperf/common/exceptions.py
62
63
class ValidationError(AIPerfError):
    """Exception raised when something fails validation.

    Carries no extra state beyond :class:`AIPerfError`.
    """

aiperf.common.factories

ComposerFactory

Bases: FactoryMixin['ComposerType', 'BaseDatasetComposer']

Factory for registering and creating BaseDatasetComposer instances based on the specified composer type.

Example:

    # Register a new composer type
    @ComposerFactory.register(ComposerType.SYNTHETIC)
    class SyntheticDatasetComposer(BaseDatasetComposer):
        pass

    # Create a new composer instance
    composer = ComposerFactory.create_instance(
        ComposerType.SYNTHETIC,
        config=InputConfig(
            conversation=ConversationConfig(num=10),
            prompt=PromptConfig(batch_size=10),
        )
    )
Source code in aiperf/common/factories.py
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
class ComposerFactory(FactoryMixin["ComposerType", "BaseDatasetComposer"]):
    """Factory for registering and creating BaseDatasetComposer instances
    based on the specified composer type.

    Example:
    ```python
        # Register a new composer type
        @ComposerFactory.register(ComposerType.SYNTHETIC)
        class SyntheticDatasetComposer(BaseDatasetComposer):
            pass

        # Create a new composer instance
        composer = ComposerFactory.create_instance(
            ComposerType.SYNTHETIC,
            config=InputConfig(
                conversation=ConversationConfig(num=10),
                prompt=PromptConfig(batch_size=10),
            )
        )
    ```
    """

CustomDatasetFactory

Bases: FactoryMixin['CustomDatasetType', 'CustomDatasetLoaderProtocol']

Factory for registering and creating CustomDatasetLoader instances based on the specified custom dataset type.

Example:

    # Register a new custom dataset type
    @CustomDatasetFactory.register(CustomDatasetType.MOONCAKE_TRACE)
    class MooncakeTraceDatasetLoader(CustomDatasetLoader):
        pass

    # Create a new custom dataset loader instance
    custom_dataset_loader = CustomDatasetFactory.create_instance(
        CustomDatasetType.MOONCAKE_TRACE, **kwargs
    )
Source code in aiperf/common/factories.py
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
class CustomDatasetFactory(
    FactoryMixin["CustomDatasetType", "CustomDatasetLoaderProtocol"]
):
    """
    Factory for registering and creating CustomDatasetLoader instances
    based on the specified custom dataset type.

    Example:
    ```python
        # Register a new custom dataset type
        @CustomDatasetFactory.register(CustomDatasetType.MOONCAKE_TRACE)
        class MooncakeTraceDatasetLoader(CustomDatasetLoader):
            pass

        # Create a new custom dataset loader instance
        custom_dataset_loader = CustomDatasetFactory.create_instance(
            CustomDatasetType.MOONCAKE_TRACE, **kwargs
        )
    ```
    """

DataExporterFactory

Bases: FactoryMixin['DataExporterType', 'DataExporterProtocol']

Factory for registering and creating DataExporterInterface instances.

Example:

    # Iterate over all registered data exporter types
    for exporter_class in DataExporterFactory.get_all_classes():
        exporter = exporter_class(endpoint_config)

        exporter.export()
Source code in aiperf/common/factories.py
244
245
246
247
248
249
250
251
252
253
254
255
class DataExporterFactory(FactoryMixin["DataExporterType", "DataExporterProtocol"]):
    """Factory for registering and creating DataExporterInterface instances.

    Example:
    ```python
        # Iterate over all registered data exporter types
        for exporter_class in DataExporterFactory.get_all_classes():
            exporter = exporter_class(endpoint_config)

            exporter.export()
    ```
    """

FactoryMixin

Bases: Generic[ClassEnumT, ClassProtocolT]

Defines a mixin for all factories, which supports registering and creating instances of classes.

This mixin is used to create a factory for a given class type and protocol.

Example:

    # Define a new enum for the expected implementation types
    # This is optional, but recommended for type safety.
    class DatasetLoaderType(CaseInsensitiveStrEnum):
        FILE = "file"
        S3 = "s3"

    # Define a new class protocol.
    class DatasetLoaderProtocol(Protocol):
        def load(self) -> Dataset:
            pass

    # Create a new factory for a given class type and protocol.
    class DatasetFactory(FactoryMixin[DatasetLoaderType, DatasetLoaderProtocol]):
        pass

    # Register a new class type mapping to its corresponding class. It should implement the class protocol.
    @DatasetFactory.register(DatasetLoaderType.FILE)
    class FileDatasetLoader:
        def __init__(self, filename: str):
            self.filename = filename

        def load(self) -> Dataset:
            return Dataset.from_file(self.filename)

    DatasetConfig = {
        "type": DatasetLoaderType.FILE,
        "filename": "data.csv"
    }

    # Create a new instance of the class.
    if DatasetConfig["type"] == DatasetLoaderType.FILE:
        dataset_instance = DatasetFactory.create_instance(DatasetLoaderType.FILE, filename=DatasetConfig["filename"])
    else:
        raise ValueError(f"Unsupported dataset loader type: {DatasetConfig['type']}")

    dataset_instance.load()
Source code in aiperf/common/factories.py
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
class FactoryMixin(Generic[ClassEnumT, ClassProtocolT]):
    """Defines a mixin for all factories, which supports registering and creating instances of classes.

    This mixin is used to create a factory for a given class type and protocol.

    Example:
    ```python
        # Define a new enum for the expected implementation types
        # This is optional, but recommended for type safety.
        class DatasetLoaderType(CaseInsensitiveStrEnum):
            FILE = "file"
            S3 = "s3"

        # Define a new class protocol.
        class DatasetLoaderProtocol(Protocol):
            def load(self) -> Dataset:
                pass

        # Create a new factory for a given class type and protocol.
        class DatasetFactory(FactoryMixin[DatasetLoaderType, DatasetLoaderProtocol]):
            pass

        # Register a new class type mapping to its corresponding class. It should implement the class protocol.
        @DatasetFactory.register(DatasetLoaderType.FILE)
        class FileDatasetLoader:
            def __init__(self, filename: str):
                self.filename = filename

            def load(self) -> Dataset:
                return Dataset.from_file(self.filename)

        DatasetConfig = {
            "type": DatasetLoaderType.FILE,
            "filename": "data.csv"
        }

        # Create a new instance of the class.
        if DatasetConfig["type"] == DatasetLoaderType.FILE:
            dataset_instance = DatasetFactory.create_instance(DatasetLoaderType.FILE, filename=DatasetConfig["filename"])
        else:
            raise ValueError(f"Unsupported dataset loader type: {DatasetConfig['type']}")

        dataset_instance.load()
    ```
    """

    logger = logging.getLogger(__name__)

    _registry: dict[ClassEnumT | str, type[ClassProtocolT]]
    _override_priorities: dict[ClassEnumT | str, int]

    def __init_subclass__(cls) -> None:
        cls._registry = {}
        cls._override_priorities = {}
        cls.logger = logging.getLogger(cls.__name__)
        super().__init_subclass__()

    @classmethod
    def register_all(
        cls, *class_types: ClassEnumT | str, override_priority: int = 0
    ) -> Callable:
        """Register multiple class types mapping to a single corresponding class.
        This is useful if a single class implements multiple types. Currently only supports
        registering as a single override priority for all types."""

        def decorator(class_cls: type[ClassProtocolT]) -> type[ClassProtocolT]:
            for class_type in class_types:
                cls.register(class_type, override_priority)(class_cls)
            return class_cls

        return decorator

    @classmethod
    def register(
        cls, class_type: ClassEnumT | str, override_priority: int = 0
    ) -> Callable:
        """Register a new class type mapping to its corresponding class.

        Args:
            class_type: The type of class to register
            override_priority: The priority of the override. The higher the priority,
                the more precedence the override has when multiple classes are registered
                for the same class type. Built-in classes have a priority of 0.

        Returns:
            Decorator for the class that implements the class protocol
        """

        def decorator(class_cls: type[ClassProtocolT]) -> type[ClassProtocolT]:
            existing_priority = cls._override_priorities.get(class_type, -1)
            if class_type in cls._registry and existing_priority >= override_priority:
                cls.logger.warning(
                    "%r class %s already registered with same or higher priority "
                    "(%s). The new registration of class %s with priority "
                    "%s will be ignored.",
                    class_type,
                    cls._registry[class_type].__name__,
                    existing_priority,
                    class_cls.__name__,
                    override_priority,
                )
                return class_cls

            if class_type not in cls._registry:
                cls.logger.debug(
                    "%r class %s registered with priority %s.",
                    class_type,
                    class_cls.__name__,
                    override_priority,
                )
            else:
                cls.logger.warning(
                    "%r class %s with priority %s overrides "
                    "already registered class %s with lower priority (%s).",
                    class_type,
                    class_cls.__name__,
                    override_priority,
                    cls._registry[class_type].__name__,
                    existing_priority,
                )
            cls._registry[class_type] = class_cls
            cls._override_priorities[class_type] = override_priority
            return class_cls

        return decorator

    @classmethod
    def create_instance(
        cls,
        class_type: ClassEnumT | str,
        **kwargs: Any,
    ) -> ClassProtocolT:
        """Create a new class instance.

        Args:
            class_type: The type of class to create
            **kwargs: Additional arguments for the class

        Returns:
            The created class instance

        Raises:
            FactoryCreationError: If the class type is not registered or there is an error creating the instance
        """
        if class_type not in cls._registry:
            raise FactoryCreationError(f"No implementation found for {class_type!r}.")
        try:
            return cls._registry[class_type](**kwargs)
        except Exception as e:
            raise FactoryCreationError(
                f"Error creating {class_type!r} instance: {e}"
            ) from e

    @classmethod
    def get_class_from_type(cls, class_type: ClassEnumT | str) -> type[ClassProtocolT]:
        """Get the class from a class type.

        Args:
            class_type: The class type to get the class from

        Returns:
            The class for the given class type

        Raises:
            TypeError: If the class type is not registered
        """
        if class_type not in cls._registry:
            raise TypeError(
                f"No class found for {class_type!r}. Please register the class first."
            )
        return cls._registry[class_type]

    @classmethod
    def get_all_classes(cls) -> list[type[ClassProtocolT]]:
        """Get all registered classes.

        Returns:
            A list of all registered class types implementing the expected protocol
        """
        return list(cls._registry.values())

    @classmethod
    def get_all_class_types(cls) -> list[ClassEnumT | str]:
        """Get all registered class types."""
        return list(cls._registry.keys())

    @classmethod
    def get_all_classes_and_types(
        cls,
    ) -> list[tuple[type[ClassProtocolT], ClassEnumT | str]]:
        """Get all registered classes and their corresponding class types."""
        return [(cls, class_type) for class_type, cls in cls._registry.items()]

create_instance(class_type, **kwargs) classmethod

Create a new class instance.

Parameters:

Name Type Description Default
class_type ClassEnumT | str

The type of class to create

required
**kwargs Any

Additional arguments for the class

{}

Returns:

Type Description
ClassProtocolT

The created class instance

Raises:

Type Description
FactoryCreationError

If the class type is not registered or there is an error creating the instance

Source code in aiperf/common/factories.py
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
@classmethod
def create_instance(
    cls,
    class_type: ClassEnumT | str,
    **kwargs: Any,
) -> ClassProtocolT:
    """Create a new class instance.

    Args:
        class_type: The type of class to create
        **kwargs: Additional arguments for the class

    Returns:
        The created class instance

    Raises:
        FactoryCreationError: If the class type is not registered or there is an error creating the instance
    """
    if class_type not in cls._registry:
        raise FactoryCreationError(f"No implementation found for {class_type!r}.")
    try:
        return cls._registry[class_type](**kwargs)
    except Exception as e:
        raise FactoryCreationError(
            f"Error creating {class_type!r} instance: {e}"
        ) from e

get_all_class_types() classmethod

Get all registered class types.

Source code in aiperf/common/factories.py
203
204
205
206
@classmethod
def get_all_class_types(cls) -> list[ClassEnumT | str]:
    """Get all registered class types."""
    return list(cls._registry.keys())

get_all_classes() classmethod

Get all registered classes.

Returns:

Type Description
list[type[ClassProtocolT]]

A list of all registered class types implementing the expected protocol

Source code in aiperf/common/factories.py
194
195
196
197
198
199
200
201
@classmethod
def get_all_classes(cls) -> list[type[ClassProtocolT]]:
    """Get all registered classes.

    Returns:
        A list of all registered class types implementing the expected protocol
    """
    return list(cls._registry.values())

get_all_classes_and_types() classmethod

Get all registered classes and their corresponding class types.

Source code in aiperf/common/factories.py
208
209
210
211
212
213
@classmethod
def get_all_classes_and_types(
    cls,
) -> list[tuple[type[ClassProtocolT], ClassEnumT | str]]:
    """Get all registered classes and their corresponding class types."""
    return [(cls, class_type) for class_type, cls in cls._registry.items()]

get_class_from_type(class_type) classmethod

Get the class from a class type.

Parameters:

Name Type Description Default
class_type ClassEnumT | str

The class type to get the class from

required

Returns:

Type Description
type[ClassProtocolT]

The class for the given class type

Raises:

Type Description
TypeError

If the class type is not registered

Source code in aiperf/common/factories.py
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
@classmethod
def get_class_from_type(cls, class_type: ClassEnumT | str) -> type[ClassProtocolT]:
    """Get the class from a class type.

    Args:
        class_type: The class type to get the class from

    Returns:
        The class for the given class type

    Raises:
        TypeError: If the class type is not registered
    """
    if class_type not in cls._registry:
        raise TypeError(
            f"No class found for {class_type!r}. Please register the class first."
        )
    return cls._registry[class_type]

register(class_type, override_priority=0) classmethod

Register a new class type mapping to its corresponding class.

Parameters:

Name Type Description Default
class_type ClassEnumT | str

The type of class to register

required
override_priority int

The priority of the override. The higher the priority, the more precedence the override has when multiple classes are registered for the same class type. Built-in classes have a priority of 0.

0

Returns:

Type Description
Callable

Decorator for the class that implements the class protocol

Source code in aiperf/common/factories.py
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
@classmethod
def register(
    cls, class_type: ClassEnumT | str, override_priority: int = 0
) -> Callable:
    """Register a new class type mapping to its corresponding class.

    Args:
        class_type: The type of class to register
        override_priority: The priority of the override. The higher the priority,
            the more precedence the override has when multiple classes are registered
            for the same class type. Built-in classes have a priority of 0.

    Returns:
        Decorator for the class that implements the class protocol
    """

    def decorator(class_cls: type[ClassProtocolT]) -> type[ClassProtocolT]:
        existing_priority = cls._override_priorities.get(class_type, -1)
        if class_type in cls._registry and existing_priority >= override_priority:
            cls.logger.warning(
                "%r class %s already registered with same or higher priority "
                "(%s). The new registration of class %s with priority "
                "%s will be ignored.",
                class_type,
                cls._registry[class_type].__name__,
                existing_priority,
                class_cls.__name__,
                override_priority,
            )
            return class_cls

        if class_type not in cls._registry:
            cls.logger.debug(
                "%r class %s registered with priority %s.",
                class_type,
                class_cls.__name__,
                override_priority,
            )
        else:
            cls.logger.warning(
                "%r class %s with priority %s overrides "
                "already registered class %s with lower priority (%s).",
                class_type,
                class_cls.__name__,
                override_priority,
                cls._registry[class_type].__name__,
                existing_priority,
            )
        cls._registry[class_type] = class_cls
        cls._override_priorities[class_type] = override_priority
        return class_cls

    return decorator

register_all(*class_types, override_priority=0) classmethod

Register multiple class types mapping to a single corresponding class. This is useful if a single class implements multiple types. Currently only supports registering as a single override priority for all types.

Source code in aiperf/common/factories.py
79
80
81
82
83
84
85
86
87
88
89
90
91
92
@classmethod
def register_all(
    cls, *class_types: ClassEnumT | str, override_priority: int = 0
) -> Callable:
    """Register multiple class types mapping to a single corresponding class.
    This is useful if a single class implements multiple types. Currently only supports
    registering as a single override priority for all types."""

    def decorator(class_cls: type[ClassProtocolT]) -> type[ClassProtocolT]:
        for class_type in class_types:
            cls.register(class_type, override_priority)(class_cls)
        return class_cls

    return decorator

PostProcessorFactory

Bases: FactoryMixin['PostProcessorType', 'PostProcessorProtocol']

Factory for registering and creating PostProcessor instances based on the specified post-processor type.

Example:

    # Register a new post-processor type
    @PostProcessorFactory.register(PostProcessorType.METRIC_SUMMARY)
    class MetricSummary:
        pass

# Create a new post-processor instance
post_processor = PostProcessorFactory.create_instance(
    PostProcessorType.METRIC_SUMMARY,
)
Source code in aiperf/common/factories.py
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
class PostProcessorFactory(FactoryMixin["PostProcessorType", "PostProcessorProtocol"]):
    """Factory for registering and creating PostProcessor instances based on the specified post-processor type.

    Example:
    ```python
        # Register a new post-processor type
        @PostProcessorFactory.register(PostProcessorType.METRIC_SUMMARY)
        class MetricSummary:
            pass

        # Create a new post-processor instance
        post_processor = PostProcessorFactory.create_instance(
            PostProcessorType.METRIC_SUMMARY,
        )
    """

ServiceFactory

Bases: FactoryMixin[ServiceType, 'BaseService']

Factory for registering and creating BaseService instances based on the specified service type.

Example:

    # Register a new service type
    @ServiceFactory.register(ServiceType.DATASET_MANAGER)
    class DatasetManager(BaseService):
        pass

    # Create a new service instance in a separate process
    service_class = ServiceFactory.get_class_from_type(service_type)

    process = Process(
        target=bootstrap_and_run_service,
        name=f"{service_type}_process",
        args=(service_class, self.config),
        daemon=False,
    )
Source code in aiperf/common/factories.py
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
class ServiceFactory(FactoryMixin[ServiceType, "BaseService"]):
    """Factory for registering and creating BaseService instances based on the specified service type.

    Example:
    ```python
        # Register a new service type
        @ServiceFactory.register(ServiceType.DATASET_MANAGER)
        class DatasetManager(BaseService):
            pass

        # Create a new service instance in a separate process
        service_class = ServiceFactory.get_class_from_type(service_type)

        process = Process(
            target=bootstrap_and_run_service,
            name=f"{service_type}_process",
            args=(service_class, self.config),
            daemon=False,
        )
    ```
    """

StreamingPostProcessorFactory

Bases: FactoryMixin[StreamingPostProcessorType, 'StreamingPostProcessor']

Factory for creating StreamingPostProcessor instances. See FactoryMixin for more details.

Source code in aiperf/common/factories.py
320
321
322
323
324
325
class StreamingPostProcessorFactory(
    FactoryMixin[StreamingPostProcessorType, "StreamingPostProcessor"]
):
    """Factory for creating StreamingPostProcessor instances.
    see: :class:`FactoryMixin` for more details.
    """

aiperf.common.hooks

This module provides an extensive hook system for AIPerf. It is designed to be used as a mixin for classes that support hooks. It provides a simple interface for registering and running hooks.

Classes should inherit from the HooksMixin class, and specify the supported hook types by decorating the class with the supports_hooks decorator.

The hook functions are registered by decorating functions with the various hook decorators such as on_init, on_start, on_stop, etc.

The hooks are run by calling the HooksMixin.run_hooks or HooksMixin.run_hooks_async methods on the class.

More than one hook can be registered for a given hook type, and classes that inherit from classes with existing hooks will inherit the hooks from the base classes as well.

AIPERF_HOOK_TYPE = '__aiperf_hook_type__' module-attribute

Constant attribute name that marks a function's hook type.

HookType = AIPerfHook | AIPerfTaskHook | str module-attribute

Type alias for valid hook types. This is a union of the AIPerfHook enum, the AIPerfTaskHook enum, and any user-defined custom strings.

AIPerfHook

Bases: CaseInsensitiveStrEnum

Enum for the various AIPerf hooks.

Note: If you add a new hook, you must also add it to the @supports_hooks decorator of the class you wish to use the hook in.

Source code in aiperf/common/hooks.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
class AIPerfHook(CaseInsensitiveStrEnum):
    """Enum for the various AIPerf hooks.

    Note: If you add a new hook, you must also add it to the @supports_hooks
    decorator of the class you wish to use the hook in.
    """

    ON_INIT = "__aiperf_on_init__"
    ON_RUN = "__aiperf_on_run__"
    ON_CONFIGURE = "__aiperf_on_configure__"
    ON_PROFILE_CONFIGURE = "__aiperf_on_profile_configure__"
    ON_PROFILE_START = "__aiperf_on_profile_start__"
    ON_PROFILE_STOP = "__aiperf_on_profile_stop__"
    ON_START = "__aiperf_on_start__"
    ON_STOP = "__aiperf_on_stop__"
    ON_CLEANUP = "__aiperf_on_cleanup__"

    ON_SET_STATE = "__aiperf_on_set_state__"

AIPerfTaskHook

Bases: CaseInsensitiveStrEnum

Enum for the various AIPerf task hooks.

Source code in aiperf/common/hooks.py
58
59
60
61
62
63
class AIPerfTaskHook(CaseInsensitiveStrEnum):
    """Enum for the various AIPerf task hooks."""

    AIPERF_TASK = "__aiperf_task__"
    AIPERF_AUTO_TASK = "__aiperf_auto_task__"
    AIPERF_AUTO_TASK_INTERVAL = "__aiperf_auto_task_interval__"

HookSystem

System for managing hooks.

This class is responsible for managing the hooks for a class. It will store the hooks in a dictionary, and provide methods to register and run the hooks.

Source code in aiperf/common/hooks.py
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
class HookSystem:
    """
    System for managing hooks.

    This class is responsible for managing the hooks for a class. It will
    store the hooks in a dictionary, and provide methods to register and run
    the hooks.
    """

    def __init__(self, supported_hooks: set[HookType]):
        """
        Initialize the hook system.

        Args:
            supported_hooks: The hook types that the class supports.
        """
        self.logger = logging.getLogger(__class__.__name__)
        self.supported_hooks: set[HookType] = supported_hooks
        self._hooks: dict[HookType, list[Callable]] = {}

    def register_hook(self, hook_type: HookType, func: Callable):
        """Register a hook function for a given hook type.

        Args:
            hook_type: The hook type to register the function for.
            func: The function to register.
        """
        if hook_type not in self.supported_hooks:
            raise UnsupportedHookError(f"Hook {hook_type} is not supported by class.")

        self._hooks.setdefault(hook_type, []).append(func)

    def get_hooks(self, hook_type: HookType) -> list[Callable]:
        """Get all the registered hooks for the given hook type.

        Args:
            hook_type: The hook type to get the hooks for.

        Returns:
            A list of the hooks for the given hook type.
        """
        return self._hooks.get(hook_type, [])

    async def run_hooks(self, hook_type: HookType, *args, **kwargs):
        """
        Run all the hooks for a given hook type serially. This will wait for each
        hook to complete before running the next one.

        Args:
            hook_type: The hook type to run.
            *args: The arguments to pass to the hooks.
            **kwargs: The keyword arguments to pass to the hooks.
        """
        if hook_type not in self.supported_hooks:
            raise UnsupportedHookError(f"Hook {hook_type} is not supported by class.")

        exceptions: list[Exception] = []
        for func in self.get_hooks(hook_type):
            try:
                if inspect.iscoroutinefunction(func):
                    await func(*args, **kwargs)
                else:
                    await asyncio.to_thread(func, *args, **kwargs)
            except Exception as e:
                self.logger.exception("Error running hook %s: %s", func.__qualname__, e)
                exceptions.append(
                    AIPerfError(
                        f"Error running hook {func.__qualname__}: {e.__class__.__name__} {e}"
                    )
                )

        if exceptions:
            raise AIPerfMultiError("Errors running hooks", exceptions)

    async def run_hooks_async(self, hook_type: HookType, *args, **kwargs):
        """
        Run all the hooks for a given hook type concurrently. This will run all
        the hooks at the same time and return when all the hooks have completed.

        Args:
            hook_type: The hook type to run.
            *args: The arguments to pass to the hooks.
            **kwargs: The keyword arguments to pass to the hooks.
        """
        if hook_type not in self.supported_hooks:
            raise UnsupportedHookError(f"Hook {hook_type} is not supported by class.")

        coroutines: list[Awaitable] = []
        for func in self.get_hooks(hook_type):
            if inspect.iscoroutinefunction(func):
                coroutines.append(func(*args, **kwargs))
            else:
                coroutines.append(asyncio.to_thread(func, *args, **kwargs))

        if coroutines:
            results = await asyncio.gather(*coroutines, return_exceptions=True)

            exceptions = [result for result in results if isinstance(result, Exception)]
            if exceptions:
                raise AIPerfMultiError("Errors running hooks", exceptions)

__init__(supported_hooks)

Initialize the hook system.

Parameters:

Name Type Description Default
supported_hooks set[HookType]

The hook types that the class supports.

required
Source code in aiperf/common/hooks.py
88
89
90
91
92
93
94
95
96
97
def __init__(self, supported_hooks: set[HookType]):
    """
    Initialize the hook system.

    Args:
        supported_hooks: The hook types that the class supports.
    """
    self.logger = logging.getLogger(__class__.__name__)
    self.supported_hooks: set[HookType] = supported_hooks
    self._hooks: dict[HookType, list[Callable]] = {}

get_hooks(hook_type)

Get all the registered hooks for the given hook type.

Parameters:

Name Type Description Default
hook_type HookType

The hook type to get the hooks for.

required

Returns:

Type Description
list[Callable]

A list of the hooks for the given hook type.

Source code in aiperf/common/hooks.py
111
112
113
114
115
116
117
118
119
120
def get_hooks(self, hook_type: HookType) -> list[Callable]:
    """Get all the registered hooks for the given hook type.

    Args:
        hook_type: The hook type to get the hooks for.

    Returns:
        A list of the hooks for the given hook type.
    """
    return self._hooks.get(hook_type, [])

register_hook(hook_type, func)

Register a hook function for a given hook type.

Parameters:

Name Type Description Default
hook_type HookType

The hook type to register the function for.

required
func Callable

The function to register.

required
Source code in aiperf/common/hooks.py
 99
100
101
102
103
104
105
106
107
108
109
def register_hook(self, hook_type: HookType, func: Callable):
    """Register a hook function for a given hook type.

    Args:
        hook_type: The hook type to register the function for.
        func: The function to register.
    """
    if hook_type not in self.supported_hooks:
        raise UnsupportedHookError(f"Hook {hook_type} is not supported by class.")

    self._hooks.setdefault(hook_type, []).append(func)

run_hooks(hook_type, *args, **kwargs) async

Run all the hooks for a given hook type serially. This will wait for each hook to complete before running the next one.

Parameters:

Name Type Description Default
hook_type HookType

The hook type to run.

required
*args

The arguments to pass to the hooks.

()
**kwargs

The keyword arguments to pass to the hooks.

{}
Source code in aiperf/common/hooks.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
async def run_hooks(self, hook_type: HookType, *args, **kwargs):
    """
    Run all the hooks for a given hook type serially. This will wait for each
    hook to complete before running the next one.

    Args:
        hook_type: The hook type to run.
        *args: The arguments to pass to the hooks.
        **kwargs: The keyword arguments to pass to the hooks.
    """
    if hook_type not in self.supported_hooks:
        raise UnsupportedHookError(f"Hook {hook_type} is not supported by class.")

    exceptions: list[Exception] = []
    for func in self.get_hooks(hook_type):
        try:
            if inspect.iscoroutinefunction(func):
                await func(*args, **kwargs)
            else:
                await asyncio.to_thread(func, *args, **kwargs)
        except Exception as e:
            self.logger.exception("Error running hook %s: %s", func.__qualname__, e)
            exceptions.append(
                AIPerfError(
                    f"Error running hook {func.__qualname__}: {e.__class__.__name__} {e}"
                )
            )

    if exceptions:
        raise AIPerfMultiError("Errors running hooks", exceptions)

run_hooks_async(hook_type, *args, **kwargs) async

Run all the hooks for a given hook type concurrently. This will run all the hooks at the same time and return when all the hooks have completed.

Parameters:

Name Type Description Default
hook_type HookType

The hook type to run.

required
*args

The arguments to pass to the hooks.

()
**kwargs

The keyword arguments to pass to the hooks.

{}
Source code in aiperf/common/hooks.py
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
async def run_hooks_async(self, hook_type: HookType, *args, **kwargs):
    """
    Run all the hooks for a given hook type concurrently. This will run all
    the hooks at the same time and return when all the hooks have completed.

    Args:
        hook_type: The hook type to run.
        *args: The arguments to pass to the hooks.
        **kwargs: The keyword arguments to pass to the hooks.
    """
    if hook_type not in self.supported_hooks:
        raise UnsupportedHookError(f"Hook {hook_type} is not supported by class.")

    coroutines: list[Awaitable] = []
    for func in self.get_hooks(hook_type):
        if inspect.iscoroutinefunction(func):
            coroutines.append(func(*args, **kwargs))
        else:
            coroutines.append(asyncio.to_thread(func, *args, **kwargs))

    if coroutines:
        results = await asyncio.gather(*coroutines, return_exceptions=True)

        exceptions = [result for result in results if isinstance(result, Exception)]
        if exceptions:
            raise AIPerfMultiError("Errors running hooks", exceptions)

aiperf_auto_task(interval)

Decorator to indicate that the function is a task function. It will be started and stopped automatically by the base class lifecycle. See :func:`aiperf.common.hooks.hook_decorator`.

Parameters:

Name Type Description Default
interval float

The interval in seconds to sleep between runs.

required
Source code in aiperf/common/hooks.py
317
318
319
320
321
322
323
324
325
326
327
328
329
330
def aiperf_auto_task(interval: float) -> Callable[[Callable], Callable]:
    """Decorator to indicate that the function is a task function. It will be started
    and stopped automatically by the base class lifecycle.
    See :func:`aiperf.common.hooks.hook_decorator`.

    Args:
        interval: The interval in seconds to sleep between runs.
    """

    def decorator(func: Callable) -> Callable:
        setattr(func, AIPerfTaskHook.AIPERF_AUTO_TASK_INTERVAL, interval)
        return hook_decorator(AIPerfTaskHook.AIPERF_AUTO_TASK, func)

    return decorator

aiperf_task(func)

Decorator to indicate that the function is a task function. It will be started and stopped automatically by the base class lifecycle. See :func:`aiperf.common.hooks.hook_decorator`.

Source code in aiperf/common/hooks.py
307
308
309
310
311
312
313
314
def aiperf_task(
    func: Callable,
) -> Callable:
    """Decorator to indicate that the function is a task function. It will be started
    and stopped automatically by the base class lifecycle.
    See :func:`aiperf.common.hooks.hook_decorator`.
    """
    return hook_decorator(AIPerfTaskHook.AIPERF_TASK, func)

hook_decorator(hook_type, func)

Generic decorator to specify that the function should be called during a specific hook.

Parameters:

Name Type Description Default
hook_type HookType

The hook type to decorate the function with.

required
func Callable

The function to decorate.

required

Returns: The decorated function.

Source code in aiperf/common/hooks.py
226
227
228
229
230
231
232
233
234
235
236
237
def hook_decorator(hook_type: HookType, func: Callable) -> Callable:
    """Generic decorator to specify that the function should be called during
    a specific hook.

    Args:
        hook_type: The hook type to decorate the function with.
        func: The function to decorate.
    Returns:
        The decorated function.
    """
    setattr(func, AIPERF_HOOK_TYPE, hook_type)
    return func

on_cleanup(func)

Decorator to specify that the function should be called during cleanup. See :func:`aiperf.common.hooks.hook_decorator`.

Source code in aiperf/common/hooks.py
269
270
271
272
def on_cleanup(func: Callable) -> Callable:
    """Decorator to specify that the function should be called during cleanup.
    See :func:`aiperf.common.hooks.hook_decorator`."""
    return hook_decorator(AIPerfHook.ON_CLEANUP, func)

on_configure(func)

Decorator to specify that the function should be called during the service configuration. See :func:`aiperf.common.hooks.hook_decorator`.

Source code in aiperf/common/hooks.py
263
264
265
266
def on_configure(func: Callable) -> Callable:
    """Decorator to specify that the function should be called during the service configuration.
    See :func:`aiperf.common.hooks.hook_decorator`."""
    return hook_decorator(AIPerfHook.ON_CONFIGURE, func)

on_init(func)

Decorator to specify that the function should be called during initialization. See :func:`aiperf.common.hooks.hook_decorator`.

Source code in aiperf/common/hooks.py
245
246
247
248
def on_init(func: Callable) -> Callable:
    """Decorator to specify that the function should be called during initialization.
    See :func:`aiperf.common.hooks.hook_decorator`."""
    return hook_decorator(AIPerfHook.ON_INIT, func)

on_profile_configure(func)

Decorator to specify that the function should be called during the service profile configuration. See :func:`aiperf.common.hooks.hook_decorator`.

Source code in aiperf/common/hooks.py
289
290
291
292
def on_profile_configure(func: Callable) -> Callable:
    """Decorator to specify that the function should be called during the service profile configuration.
    See :func:`aiperf.common.hooks.hook_decorator`."""
    return hook_decorator(AIPerfHook.ON_PROFILE_CONFIGURE, func)

on_profile_start(func)

Decorator to specify that the function should be called during the service profile start. See :func:`aiperf.common.hooks.hook_decorator`.

Source code in aiperf/common/hooks.py
295
296
297
298
def on_profile_start(func: Callable) -> Callable:
    """Decorator to specify that the function should be called during the service profile start.
    See :func:`aiperf.common.hooks.hook_decorator`."""
    return hook_decorator(AIPerfHook.ON_PROFILE_START, func)

on_profile_stop(func)

Decorator to specify that the function should be called during the service profile stop. See :func:`aiperf.common.hooks.hook_decorator`.

Source code in aiperf/common/hooks.py
301
302
303
304
def on_profile_stop(func: Callable) -> Callable:
    """Decorator to specify that the function should be called during the service profile stop.
    See :func:`aiperf.common.hooks.hook_decorator`."""
    return hook_decorator(AIPerfHook.ON_PROFILE_STOP, func)

on_run(func)

Decorator to specify that the function should be called during run. See :func:`aiperf.common.hooks.hook_decorator`.

Source code in aiperf/common/hooks.py
275
276
277
278
def on_run(func: Callable) -> Callable:
    """Decorator to specify that the function should be called during run.
    See :func:`aiperf.common.hooks.hook_decorator`."""
    return hook_decorator(AIPerfHook.ON_RUN, func)

on_set_state(func)

Decorator to specify that the function should be called when the service state is set. See :func:`aiperf.common.hooks.hook_decorator`.

Source code in aiperf/common/hooks.py
281
282
283
284
285
286
def on_set_state(
    func: Callable,
) -> Callable:
    """Decorator to specify that the function should be called when the service state is set.
    See :func:`aiperf.common.hooks.hook_decorator`."""
    return hook_decorator(AIPerfHook.ON_SET_STATE, func)

on_start(func)

Decorator to specify that the function should be called during start. See :func:`aiperf.common.hooks.hook_decorator`.

Source code in aiperf/common/hooks.py
251
252
253
254
def on_start(func: Callable) -> Callable:
    """Decorator to specify that the function should be called during start.
    See :func:`aiperf.common.hooks.hook_decorator`."""
    return hook_decorator(AIPerfHook.ON_START, func)

on_stop(func)

Decorator to specify that the function should be called during stop. See :func:`aiperf.common.hooks.hook_decorator`.

Source code in aiperf/common/hooks.py
257
258
259
260
def on_stop(func: Callable) -> Callable:
    """Decorator to specify that the function should be called during stop.
    See :func:`aiperf.common.hooks.hook_decorator`."""
    return hook_decorator(AIPerfHook.ON_STOP, func)

supports_hooks(*supported_hook_types)

Decorator to indicate that a class supports hooks and sets the supported hook types.

Parameters:

Name Type Description Default
supported_hook_types HookType

The hook types that the class supports.

()

Returns:

Type Description
Callable[[type], type]

The decorated class

Source code in aiperf/common/hooks.py
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
def supports_hooks(
    *supported_hook_types: HookType,
) -> Callable[[type], type]:
    """Decorator to indicate that a class supports hooks and sets the
    supported hook types.

    Args:
        supported_hook_types: The hook types that the class supports.

    Returns:
        The decorated class
    """

    def decorator(cls: type) -> type:
        # TODO: We can consider creating a HooksMixinProtocol, but it would still
        #       need to exist somewhere both hooks.py and mixins module can access.
        # Import this here to prevent circular imports. Also make sure you use
        # fully qualified import name to avoid partial loaded module errors.
        from aiperf.common.mixins.hooks_mixin import HooksMixin

        # Ensure the class inherits from HooksMixin
        if not issubclass(cls, HooksMixin):
            raise TypeError(f"Class {cls.__name__} does not inherit from HooksMixin.")

        # Inherit any hooks defined by base classes in the MRO (Method Resolution Order).
        base_hooks = [
            base.supported_hooks
            for base in cls.__mro__[1:]  # Skip this class itself (cls)
            if issubclass(
                base, HooksMixin
            )  # Only include classes that inherit from HooksMixin
        ]

        # Set the supported hooks to be the union of the existing base hooks and the new supported hook types.
        cls.supported_hooks = set.union(*base_hooks, set(supported_hook_types))
        return cls

    return decorator

aiperf.common.interfaces

DataExporterProtocol

Bases: Protocol

Protocol for data exporters. Any class implementing this protocol must provide an export method that takes a list of Record objects and handles exporting them appropriately.

Source code in aiperf/common/interfaces.py
14
15
16
17
18
19
20
21
22
23
24
@runtime_checkable
class DataExporterProtocol(Protocol):
    """
    Protocol for data exporters.
    Any class implementing this protocol must provide an `export` method
    that takes a list of Record objects and handles exporting them appropriately.
    """

    async def export(self) -> None:
        """Export the data."""
        ...

export() async

Export the data.

Source code in aiperf/common/interfaces.py
22
23
24
async def export(self) -> None:
    """Export the data."""
    ...

PostProcessorProtocol

Bases: Protocol

PostProcessorProtocol is a protocol that defines the API for post-processors. It requires a `process` method that takes a list of records and returns a result.

Source code in aiperf/common/interfaces.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class PostProcessorProtocol(Protocol):
    """
    PostProcessorProtocol is a protocol that defines the API for post-processors.
    It requires an `process` method that takes a list of records and returns a result.
    """

    def process(self, records: dict) -> dict:
        """
        Execute the post-processing logic on the given payload.

        :param payload: The input data to be processed.
        :return: The processed data as a dictionary.
        """
        pass

process(records)

Execute the post-processing logic on the given payload.

:param records: The input data to be processed. :return: The processed data as a dictionary.

Source code in aiperf/common/interfaces.py
36
37
38
39
40
41
42
43
def process(self, records: dict) -> dict:
    """
    Execute the post-processing logic on the given payload.

    :param payload: The input data to be processed.
    :return: The processed data as a dictionary.
    """
    pass

ResponseExtractor

Bases: Protocol

Base class for all response extractors.

Source code in aiperf/common/interfaces.py
51
52
53
54
55
56
57
58
class ResponseExtractor(Protocol):
    """Base class for all response extractors."""

    async def extract_response_data(
        self, record: "ParsedResponseRecord"
    ) -> list["ResponseData"]:
        """Extract the text from a server response message."""
        ...

extract_response_data(record) async

Extract the text from a server response message.

Source code in aiperf/common/interfaces.py
54
55
56
57
58
async def extract_response_data(
    self, record: "ParsedResponseRecord"
) -> list["ResponseData"]:
    """Extract the text from a server response message."""
    ...

aiperf.common.logging

MultiProcessLogHandler

Bases: RichHandler

Custom logging handler that forwards log records to a multiprocessing queue.

Source code in aiperf/common/logging.py
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
class MultiProcessLogHandler(RichHandler):
    """Custom logging handler that forwards log records to a multiprocessing queue."""

    def __init__(
        self, log_queue: multiprocessing.Queue, service_id: str | None = None
    ) -> None:
        super().__init__()
        self.log_queue = log_queue
        self.service_id = service_id

    def emit(self, record: logging.LogRecord) -> None:
        """Emit a log record to the queue."""
        try:
            # Create a serializable log data structure
            log_data = {
                "name": record.name,
                "levelname": record.levelname,
                "levelno": record.levelno,
                "msg": record.getMessage(),
                "created": record.created,
                "process_name": multiprocessing.current_process().name,
                "process_id": multiprocessing.current_process().pid,
                "service_id": self.service_id,
            }
            self.log_queue.put_nowait(log_data)
        except queue.Full:
            # Drop logs if queue is full to prevent blocking
            pass
        except Exception:
            # Do not log to prevent recursion
            pass

emit(record)

Emit a log record to the queue.

Source code in aiperf/common/logging.py
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
def emit(self, record: logging.LogRecord) -> None:
    """Emit a log record to the queue."""
    try:
        # Create a serializable log data structure
        log_data = {
            "name": record.name,
            "levelname": record.levelname,
            "levelno": record.levelno,
            "msg": record.getMessage(),
            "created": record.created,
            "process_name": multiprocessing.current_process().name,
            "process_id": multiprocessing.current_process().pid,
            "service_id": self.service_id,
        }
        self.log_queue.put_nowait(log_data)
    except queue.Full:
        # Drop logs if queue is full to prevent blocking
        pass
    except Exception:
        # Do not log to prevent recursion
        pass

create_file_handler(log_folder, level)

Configure a file handler for logging.

Source code in aiperf/common/logging.py
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
def create_file_handler(
    log_folder: Path,
    level: str | int,
) -> logging.FileHandler:
    """Configure a file handler for logging."""

    log_folder.mkdir(parents=True, exist_ok=True)
    log_file_path = log_folder / "aiperf.log"

    file_handler = logging.FileHandler(log_file_path, encoding="utf-8")
    file_handler.setLevel(level)
    file_handler.setFormatter(
        logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    return file_handler

get_global_log_queue() cached

Get the global log queue. Will create a new queue if it doesn't exist.

Source code in aiperf/common/logging.py
23
24
25
26
@lru_cache(maxsize=1)
def get_global_log_queue() -> multiprocessing.Queue:
    """Get the global log queue. Will create a new queue if it doesn't exist."""
    return multiprocessing.Queue(maxsize=LOG_QUEUE_MAXSIZE)

setup_child_process_logging(log_queue=None, service_id=None, service_config=None, user_config=None)

Set up logging for a child process to send logs to the main process.

This should be called early in child process initialization.

Parameters:

Name Type Description Default
log_queue Queue | None

The multiprocessing queue to send logs to. If None, tries to get the global queue.

None
service_id str | None

The ID of the service to log under. If None, logs will be under the process name.

None
service_config ServiceConfig | None

The service configuration used to determine the log level.

None
user_config UserConfig | None

The user configuration used to determine the log folder.

None
Source code in aiperf/common/logging.py
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
def setup_child_process_logging(
    log_queue: "multiprocessing.Queue | None" = None,
    service_id: str | None = None,
    service_config: ServiceConfig | None = None,
    user_config: UserConfig | None = None,
) -> None:
    """Set up logging for a child process to send logs to the main process.

    This should be called early in child process initialization.

    Args:
        log_queue: The multiprocessing queue to send logs to. If None, tries to get the global queue.
        service_id: The ID of the service to log under. If None, logs will be under the process name.
        service_config: The service configuration used to determine the log level.
        user_config: The user configuration used to determine the log folder.
    """
    root_logger = logging.getLogger()
    level = ServiceDefaults.LOG_LEVEL.upper()
    if service_config:
        level = service_config.log_level.upper()

        if service_id:
            # If the service is in the trace or debug services, set the level to trace or debug
            if service_config.trace_services and _is_service_in_types(
                service_id, service_config.trace_services
            ):
                level = _TRACE
            elif service_config.debug_services and _is_service_in_types(
                service_id, service_config.debug_services
            ):
                level = _DEBUG

    # Set the root logger level to ensure logs are passed to handlers
    root_logger.setLevel(level)

    # Remove all existing handlers to avoid duplicate logs
    for existing_handler in root_logger.handlers[:]:
        root_logger.removeHandler(existing_handler)

    if log_queue is not None:
        # Set up handler for child process
        queue_handler = MultiProcessLogHandler(log_queue, service_id)
        queue_handler.setLevel(level)
        root_logger.addHandler(queue_handler)

    if service_config:
        # Set up rich logging to the console
        rich_handler = RichHandler(
            rich_tracebacks=True,
            show_path=True,
            console=Console(),
            show_time=True,
            show_level=True,
            tracebacks_show_locals=False,
            log_time_format="%H:%M:%S.%f",
            omit_repeated_times=False,
        )
        rich_handler.setLevel(level)
        root_logger.addHandler(rich_handler)

    if user_config and user_config.output.artifact_directory:
        file_handler = create_file_handler(
            user_config.output.artifact_directory / "logs", level
        )
        root_logger.addHandler(file_handler)

setup_rich_logging(user_config, service_config)

Set up rich logging with appropriate configuration.

Source code in aiperf/common/logging.py
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
def setup_rich_logging(user_config: UserConfig, service_config: ServiceConfig) -> None:
    """Set up rich logging with appropriate configuration."""
    # Set logging level for the root logger (affects all loggers)
    level = service_config.log_level.upper()
    logging.root.setLevel(level)

    rich_handler = RichHandler(
        rich_tracebacks=True,
        show_path=True,
        console=Console(),
        show_time=True,
        show_level=True,
        tracebacks_show_locals=False,
        log_time_format="%H:%M:%S.%f",
        omit_repeated_times=False,
    )
    logging.root.addHandler(rich_handler)

    # Enable file logging for services
    # TODO: Use config to determine if file logging is enabled and the folder path.
    log_folder = user_config.output.artifact_directory / "logs"
    log_folder.mkdir(parents=True, exist_ok=True)
    file_handler = logging.FileHandler(log_folder / "aiperf.log")
    file_handler.setLevel(level)
    file_handler.formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    logging.root.addHandler(file_handler)

    logger.debug(lambda: f"Logging initialized with level: {level}")

aiperf.common.messages.base_messages

ErrorMessage

Bases: Message

Message containing error data.

Source code in aiperf/common/messages/base_messages.py
131
132
133
134
135
136
class ErrorMessage(Message):
    """Message containing error data."""

    message_type: MessageTypeT = MessageType.ERROR

    error: ErrorDetails = Field(..., description="Error information")

Message

Bases: ExcludeIfNoneMixin

Base message class for optimized message handling.

This class provides a base for all messages, including common fields like message_type, request_ns, and request_id. It also supports optional field exclusion based on the @exclude_if_none decorator.

Each message model should inherit from this class, set the message_type field, and define its own additional fields. Optionally, the @exclude_if_none decorator can be used to specify which fields should be excluded from the serialized message if they are None.

Example:

@exclude_if_none(["some_field"])
class ExampleMessage(Message):
    some_field: int | None = Field(default=None)
    other_field: int = Field(default=1)
Source code in aiperf/common/messages/base_messages.py
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
@exclude_if_none(["request_ns", "request_id"])
class Message(ExcludeIfNoneMixin):
    """Base message class for optimized message handling.

    This class provides a base for all messages, including common fields like message_type,
    request_ns, and request_id. It also supports optional field exclusion based on the
    @exclude_if_none decorator.

    Each message model should inherit from this class, set the message_type field,
    and define its own additional fields.
    Optionally, the @exclude_if_none decorator can be used to specify which fields
    should be excluded from the serialized message if they are None.

    Example:
    ```python
    @exclude_if_none(["some_field"])
    class ExampleMessage(Message):
        some_field: int | None = Field(default=None)
        other_field: int = Field(default=1)
    ```
    """

    _exclude_if_none_fields: ClassVar[set[str]] = set()
    """Set of field names that should be excluded from the serialized message if they
    are None. This is set by the @exclude_if_none decorator.
    """

    _message_type_lookup: ClassVar[dict[MessageTypeT, type["Message"]]] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if hasattr(cls, "message_type"):
            cls._message_type_lookup[cls.message_type] = cls

    message_type: MessageTypeT = Field(
        ...,
        description="The type of the message. Must be set in the subclass.",
    )

    request_ns: int | None = Field(
        default=None,
        description="Timestamp of the request",
    )

    request_id: str | None = Field(
        default=None,
        description="ID of the request",
    )

    @model_serializer
    def _serialize_message(self) -> dict[str, Any]:
        """Serialize the message to a dictionary.

        This method overrides the default serializer to exclude fields that have a
        value of None and have the EXCLUDE_IF_NONE json_schema_extra key set to True.
        """
        return {
            k: v
            for k, v in self
            if not (k in self._exclude_if_none_fields and v is None)
        }

    @classmethod
    def __get_validators__(cls):
        yield cls.from_json

    @classmethod
    def from_json(cls, json_str: str | bytes | bytearray) -> "Message":
        """Deserialize a message from a JSON string, attempting to auto-detect the message type.
        NOTE: If you already know the message type, use the more performant :meth:`from_json_with_type` instead."""
        data = json.loads(json_str)
        message_type = data.get("message_type")
        if not message_type:
            raise ValueError(f"Missing message_type: {json_str}")

        # Use cached message type lookup
        message_class = cls._message_type_lookup.get(message_type)
        if message_class is None:
            raise ValueError(f"Unknown message type: {message_type}")

        return message_class.model_validate(data)

    @classmethod
    def from_json_with_type(
        cls, message_type: MessageTypeT, json_str: str | bytes | bytearray
    ) -> "Message":
        """Deserialize a message from a JSON string with a specific message type.
        NOTE: This is more performant than :meth:`from_json` because it does not need to
        convert the JSON string to a dictionary first."""
        # Use cached message type lookup
        message_class = cls._message_type_lookup.get(message_type)
        if message_class is None:
            raise ValueError(f"Unknown message type: {message_type}")
        return message_class.model_validate_json(json_str)

    def to_json(self) -> str:
        """Fast serialization without full validation"""
        return orjson.dumps(self.__dict__).decode("utf-8")

from_json(json_str) classmethod

Deserialize a message from a JSON string, attempting to auto-detect the message type. NOTE: If you already know the message type, use the more performant :meth:`from_json_with_type` instead.

Source code in aiperf/common/messages/base_messages.py
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
@classmethod
def from_json(cls, json_str: str | bytes | bytearray) -> "Message":
    """Deserialize a message from a JSON string, attempting to auto-detect the message type.
    NOTE: If you already know the message type, use the more performant :meth:`from_json_with_type` instead."""
    data = json.loads(json_str)
    message_type = data.get("message_type")
    if not message_type:
        raise ValueError(f"Missing message_type: {json_str}")

    # Use cached message type lookup
    message_class = cls._message_type_lookup.get(message_type)
    if message_class is None:
        raise ValueError(f"Unknown message type: {message_type}")

    return message_class.model_validate(data)

from_json_with_type(message_type, json_str) classmethod

Deserialize a message from a JSON string with a specific message type. NOTE: This is more performant than :meth:`from_json` because it does not need to convert the JSON string to a dictionary first.

Source code in aiperf/common/messages/base_messages.py
104
105
106
107
108
109
110
111
112
113
114
115
@classmethod
def from_json_with_type(
    cls, message_type: MessageTypeT, json_str: str | bytes | bytearray
) -> "Message":
    """Deserialize a message from a JSON string with a specific message type.
    NOTE: This is more performant than :meth:`from_json` because it does not need to
    convert the JSON string to a dictionary first."""
    # Use cached message type lookup
    message_class = cls._message_type_lookup.get(message_type)
    if message_class is None:
        raise ValueError(f"Unknown message type: {message_type}")
    return message_class.model_validate_json(json_str)

to_json()

Fast serialization without full validation

Source code in aiperf/common/messages/base_messages.py
117
118
119
def to_json(self) -> str:
    """Fast serialization without full validation"""
    return orjson.dumps(self.__dict__).decode("utf-8")

RequiresRequestNSMixin

Bases: Message

Mixin for messages that require a request_ns field.

Source code in aiperf/common/messages/base_messages.py
122
123
124
125
126
127
128
class RequiresRequestNSMixin(Message):
    """Mixin for messages that require a request_ns field."""

    # default_factory (not default) so the timestamp is taken per-instance
    # at construction time rather than once at class-definition time.
    request_ns: int = Field(  # type: ignore[assignment]
        default_factory=time.time_ns,
        description="Timestamp of the request in nanoseconds",
    )

aiperf.common.messages.command_messages

CommandMessage

Bases: BaseServiceMessage

Message containing command data. This message is sent by the system controller to a service to command it to do something.

Source code in aiperf/common/messages/command_messages.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
class CommandMessage(BaseServiceMessage):
    """Message containing command data.
    This message is sent by the system controller to a service to command it to do something.
    """

    # Discriminator used by Message.from_json to route deserialization.
    message_type: MessageTypeT = MessageType.COMMAND

    command: CommandType = Field(
        ...,
        description="Command to execute",
    )
    # Auto-generated per message so responses can be correlated to commands.
    command_id: str = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Unique identifier for this command. If not provided, a random UUID will be generated.",
    )
    require_response: bool = Field(
        default=False,
        description="Whether a response is required for this command",
    )
    # Targeting: both None => broadcast to all services.
    target_service_type: ServiceType | None = Field(
        default=None,
        description="Type of the service to send the command to. "
        "If both `target_service_type` and `target_service_id` are None, the command is "
        "sent to all services.",
    )
    target_service_id: str | None = Field(
        default=None,
        description="ID of the target service to send the command to. "
        "If both `target_service_type` and `target_service_id` are None, the command is "
        "sent to all services.",
    )
    # SerializeAsAny keeps subclass fields when serializing the payload.
    data: SerializeAsAny[ProcessRecordsCommandData | BaseModel | None] = Field(
        default=None,
        description="Data to send with the command",
    )

CommandResponseMessage

Bases: BaseServiceMessage

Message containing a command response. This message is sent by a component service to the system controller to respond to a command.

Source code in aiperf/common/messages/command_messages.py
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
class CommandResponseMessage(BaseServiceMessage):
    """Message containing a command response.
    This message is sent by a component service to the system controller to respond to a command.
    """

    # Discriminator used by Message.from_json to route deserialization.
    message_type: MessageTypeT = MessageType.COMMAND_RESPONSE

    command: CommandType = Field(
        ...,
        description="Command type that is being responded to",
    )
    # Must echo the originating CommandMessage.command_id for correlation.
    command_id: str = Field(
        ..., description="The ID of the command that is being responded to"
    )
    status: CommandResponseStatus = Field(..., description="The status of the command")
    data: SerializeAsAny[BaseModel | None] = Field(
        default=None,
        description="Data to send with the command response if the command succeeded",
    )
    # Populated only on failure; `data` is populated only on success.
    error: ErrorDetails | None = Field(
        default=None,
        description="Error information if the command failed",
    )

ProcessRecordsCommandData

Bases: BaseModel

Data to send with the process records command.

Source code in aiperf/common/messages/command_messages.py
24
25
26
27
28
29
30
class ProcessRecordsCommandData(BaseModel):
    """Data to send with the process records command."""

    # Lets the records processor distinguish a cancelled run from a completed one.
    cancelled: bool = Field(
        default=False,
        description="Whether the profile run was cancelled",
    )

aiperf.common.messages.credit_messages

CreditDropMessage

Bases: BaseServiceMessage

Message indicating that a credit has been dropped. This message is sent by the timing manager to workers to indicate that credit(s) have been dropped.

Source code in aiperf/common/messages/credit_messages.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class CreditDropMessage(BaseServiceMessage):
    """Message indicating that a credit has been dropped.
    This message is sent by the timing manager to workers to indicate that credit(s)
    have been dropped.
    """

    # Discriminator used by Message.from_json to route deserialization.
    message_type: MessageTypeT = MessageType.CREDIT_DROP

    phase: CreditPhase = Field(..., description="The type of credit phase")
    conversation_id: str | None = Field(
        default=None, description="The ID of the conversation, if applicable."
    )
    # None => the worker should send the request immediately.
    credit_drop_ns: int | None = Field(
        default=None,
        description="Timestamp of the credit drop, if applicable. None means send ASAP.",
    )

CreditPhaseCompleteMessage

Bases: BaseServiceMessage

Message for credit phase complete. Sent by the TimingManager to report that a credit phase has completed.

Source code in aiperf/common/messages/credit_messages.py
104
105
106
107
108
109
110
111
112
113
114
115
116
117
class CreditPhaseCompleteMessage(BaseServiceMessage):
    """Message for credit phase complete. Sent by the TimingManager to report that a credit phase has completed."""

    message_type: MessageTypeT = MessageType.CREDIT_PHASE_COMPLETE
    phase: CreditPhase = Field(..., description="The type of credit phase")
    completed: int = Field(
        ...,
        description="The number of completed credits (returned from the workers). This is the final count of completed credits.",
    )
    # ge=1 rejects zero/negative timestamps; None means still in progress.
    end_ns: int | None = Field(
        default=None,
        ge=1,
        description="The time in which the last credit was returned from the workers in nanoseconds. If None, the phase has not completed.",
    )

CreditPhaseProgressMessage

Bases: BaseServiceMessage

Sent by the TimingManager to report the progress of a credit phase.

Source code in aiperf/common/messages/credit_messages.py
81
82
83
84
85
86
87
88
89
90
class CreditPhaseProgressMessage(BaseServiceMessage):
    """Sent by the TimingManager to report the progress of a credit phase."""

    message_type: MessageTypeT = MessageType.CREDIT_PHASE_PROGRESS
    phase: CreditPhase = Field(..., description="The type of credit phase")
    # Running counters; `sent` >= `completed` while credits are in flight.
    sent: int = Field(default=0, description="The number of sent credits")
    completed: int = Field(
        default=0,
        description="The number of completed credits (returned from the workers)",
    )

CreditPhaseSendingCompleteMessage

Bases: BaseServiceMessage

Message for credit phase sending complete. Sent by the TimingManager to report that a credit phase has completed sending.

Source code in aiperf/common/messages/credit_messages.py
 93
 94
 95
 96
 97
 98
 99
100
101
class CreditPhaseSendingCompleteMessage(BaseServiceMessage):
    """Message for credit phase sending complete. Sent by the TimingManager to report that a credit phase has completed sending."""

    message_type: MessageTypeT = MessageType.CREDIT_PHASE_SENDING_COMPLETE
    phase: CreditPhase = Field(..., description="The type of credit phase")
    # "Sending complete" (all credits dispatched) precedes "phase complete"
    # (all credits returned) — compare CreditPhaseCompleteMessage.
    sent_end_ns: int | None = Field(
        default=None,
        description="The time of the last sent credit in nanoseconds. If None, the phase has not sent all credits.",
    )

CreditPhaseStartMessage

Bases: BaseServiceMessage

Message for credit phase start. Sent by the TimingManager to report that a credit phase has started.

Source code in aiperf/common/messages/credit_messages.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
class CreditPhaseStartMessage(BaseServiceMessage):
    """Message for credit phase start. Sent by the TimingManager to report that a credit phase has started."""

    message_type: MessageTypeT = MessageType.CREDIT_PHASE_START
    phase: CreditPhase = Field(..., description="The type of credit phase")
    start_ns: int = Field(
        ge=1,
        description="The start time of the credit phase in nanoseconds.",
    )
    # Exactly one of the two fields below is expected to be set depending on
    # whether the phase is request-count based or time based.
    total_expected_requests: int | None = Field(
        default=None,
        ge=1,
        description="The total number of expected requests. If None, the phase is not request count based.",
    )
    # NOTE(review): ge=1 rejects sub-second durations (e.g. 0.5s) — confirm
    # that durations under one second are intentionally disallowed.
    expected_duration_sec: float | None = Field(
        default=None,
        ge=1,
        description="The expected duration of the credit phase in seconds. If None, the phase is not time based.",
    )

CreditReturnMessage

Bases: BaseServiceMessage

Message indicating that a credit has been returned. This message is sent by a worker to the timing manager to indicate that work has been completed.

Source code in aiperf/common/messages/credit_messages.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
class CreditReturnMessage(BaseServiceMessage):
    """Message indicating that a credit has been returned.
    This message is sent by a worker to the timing manager to indicate that work has
    been completed.
    """

    # Discriminator used by Message.from_json to route deserialization.
    message_type: MessageTypeT = MessageType.CREDIT_RETURN

    phase: CreditPhase = Field(
        ...,
        description="The Credit Phase of the credit drop. This is so the TimingManager can track the progress of the credit phase.",
    )
    delayed_ns: int | None = Field(
        default=None,
        ge=1,
        description="The number of nanoseconds the credit drop was delayed by, or None if the credit was sent on time. "
        "NOTE: This is only applicable if the original credit_drop_ns was not None.",
    )
    # TODO: Does it make more sense for this to be part of the RequestRecord?
    pre_inference_ns: int | None = Field(
        default=None,
        description="The latency of the credit in nanoseconds from when it was first received to when the inference request was sent. "
        "This can be used to trace the latency in order to identify bottlenecks or other issues.",
        ge=0,
    )

    @property
    def delayed(self) -> bool:
        """Whether the credit missed its scheduled drop time (``delayed_ns`` was recorded)."""
        return self.delayed_ns is not None

CreditsCompleteMessage

Bases: BaseServiceMessage

Credits complete message sent by the TimingManager to the System controller to signify all Credit Phases have been completed.

Source code in aiperf/common/messages/credit_messages.py
120
121
122
123
124
class CreditsCompleteMessage(BaseServiceMessage):
    """Credits complete message sent by the TimingManager to the System controller to signify all Credit Phases
    have been completed."""

    # Pure signal message — the discriminator is the only payload.
    message_type: MessageTypeT = MessageType.CREDITS_COMPLETE

aiperf.common.messages.dataset_messages

ConversationRequestMessage

Bases: BaseServiceMessage

Message to request a full conversation by ID.

Source code in aiperf/common/messages/dataset_messages.py
12
13
14
15
16
17
18
19
20
21
22
23
class ConversationRequestMessage(BaseServiceMessage):
    """Message to request a full conversation by ID."""

    message_type: MessageTypeT = MessageType.CONVERSATION_REQUEST

    # None is allowed; the responder presumably picks a conversation itself.
    conversation_id: str | None = Field(
        default=None, description="The session ID of the conversation"
    )
    credit_phase: CreditPhase | None = Field(
        default=None,
        description="The type of credit phase (either warmup or profiling). If not provided, the timing manager will use the default credit phase.",
    )

ConversationResponseMessage

Bases: BaseServiceMessage

Message containing a full conversation.

Source code in aiperf/common/messages/dataset_messages.py
26
27
28
29
30
class ConversationResponseMessage(BaseServiceMessage):
    """Message containing a full conversation.

    Reply to :class:`ConversationRequestMessage`.
    """

    message_type: MessageTypeT = MessageType.CONVERSATION_RESPONSE
    conversation: Conversation = Field(..., description="The conversation data")

ConversationTurnRequestMessage

Bases: BaseServiceMessage

Message to request a single turn from a conversation.

Source code in aiperf/common/messages/dataset_messages.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class ConversationTurnRequestMessage(BaseServiceMessage):
    """Message to request a single turn from a conversation."""

    message_type: MessageTypeT = MessageType.CONVERSATION_TURN_REQUEST

    # Unlike ConversationRequestMessage, the conversation ID is required here.
    conversation_id: str = Field(
        ...,
        description="The ID of the conversation.",
    )
    turn_index: int = Field(
        ...,
        ge=0,
        description="The index of the turn in the conversation.",
    )

ConversationTurnResponseMessage

Bases: BaseServiceMessage

Message containing a single turn from a conversation.

Source code in aiperf/common/messages/dataset_messages.py
49
50
51
52
53
54
class ConversationTurnResponseMessage(BaseServiceMessage):
    """Message containing a single turn from a conversation.

    Reply to :class:`ConversationTurnRequestMessage`.
    """

    message_type: MessageTypeT = MessageType.CONVERSATION_TURN_RESPONSE

    turn: Turn = Field(..., description="The turn data")

DatasetConfiguredNotification

Bases: BaseServiceMessage

Notification sent to notify other services that the dataset has been configured.

Source code in aiperf/common/messages/dataset_messages.py
74
75
76
77
class DatasetConfiguredNotification(BaseServiceMessage):
    """Notification sent to notify other services that the dataset has been configured."""

    # Pure signal message — the discriminator is the only payload.
    message_type: MessageTypeT = MessageType.DATASET_CONFIGURED_NOTIFICATION

DatasetTimingRequest

Bases: BaseServiceMessage

Message for a dataset timing request.

Source code in aiperf/common/messages/dataset_messages.py
57
58
59
60
class DatasetTimingRequest(BaseServiceMessage):
    """Message for a dataset timing request.

    Answered by :class:`DatasetTimingResponse`.
    """

    message_type: MessageTypeT = MessageType.DATASET_TIMING_REQUEST

DatasetTimingResponse

Bases: BaseServiceMessage

Message for a dataset timing response.

Source code in aiperf/common/messages/dataset_messages.py
63
64
65
66
67
68
69
70
71
class DatasetTimingResponse(BaseServiceMessage):
    """Message for a dataset timing response.

    Reply to :class:`DatasetTimingRequest`.
    """

    message_type: MessageTypeT = MessageType.DATASET_TIMING_RESPONSE

    timing_data: list[tuple[int, str]] = Field(
        ...,
        description="The timing data of the dataset. Tuple of (timestamp, conversation_id)",
    )

aiperf.common.messages.health_messages

WorkerHealthMessage

Bases: BaseServiceMessage

Message for a worker health check.

Source code in aiperf/common/messages/health_messages.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
class WorkerHealthMessage(BaseServiceMessage):
    """Message for a worker health check.

    Carries process health plus per-phase task statistics; the aggregate
    properties below sum the per-phase stats across all credit phases.
    """

    message_type: MessageTypeT = MessageType.WORKER_HEALTH

    process: ProcessHealth = Field(..., description="The health of the worker process")

    # Worker specific fields
    task_stats: dict[CreditPhase, WorkerPhaseTaskStats] = Field(
        ...,
        description="Stats for the tasks that have been sent to the worker, keyed by the credit phase",
    )

    @property
    def total_tasks(self) -> int:
        """The total number of tasks that have been sent to the worker."""
        return sum(task_stats.total for task_stats in self.task_stats.values())

    @property
    def completed_tasks(self) -> int:
        """The number of tasks that have been completed by the worker."""
        return sum(task_stats.completed for task_stats in self.task_stats.values())

    @property
    def failed_tasks(self) -> int:
        """The number of tasks that have failed by the worker."""
        return sum(task_stats.failed for task_stats in self.task_stats.values())

    @property
    def in_progress_tasks(self) -> int:
        """The number of tasks that are currently in progress by the worker."""
        return sum(task_stats.in_progress for task_stats in self.task_stats.values())

    @property
    def error_rate(self) -> float:
        """The error rate of the worker (failed / total), or 0.0 when no tasks were sent."""
        if self.total_tasks == 0:
            # BUG FIX: previously returned int 0, contradicting the declared
            # float return type; return a float for a consistent result type.
            return 0.0
        return self.failed_tasks / self.total_tasks

completed_tasks property

The number of tasks that have been completed by the worker.

error_rate property

The error rate of the worker.

failed_tasks property

The number of tasks that have failed by the worker.

in_progress_tasks property

The number of tasks that are currently in progress by the worker.

total_tasks property

The total number of tasks that have been sent to the worker.

aiperf.common.messages.inference_messages

InferenceResultsMessage

Bases: BaseServiceMessage

Message for inference results.

Source code in aiperf/common/messages/inference_messages.py
20
21
22
23
24
25
26
27
class InferenceResultsMessage(BaseServiceMessage):
    """Message for inference results."""

    message_type: MessageTypeT = MessageType.INFERENCE_RESULTS

    # SerializeAsAny keeps subclass fields when serializing the record.
    record: SerializeAsAny[RequestRecord] = Field(
        ..., description="The inference results record"
    )

ParsedInferenceResultsMessage

Bases: BaseServiceMessage

Message for parsed inference results.

Source code in aiperf/common/messages/inference_messages.py
30
31
32
33
34
35
36
37
class ParsedInferenceResultsMessage(BaseServiceMessage):
    """Message for parsed inference results."""

    message_type: MessageTypeT = MessageType.PARSED_INFERENCE_RESULTS

    # SerializeAsAny keeps subclass fields when serializing the record.
    record: SerializeAsAny[ParsedResponseRecord] = Field(
        ..., description="The post process results record"
    )

aiperf.common.messages.progress_messages

ProcessingStatsMessage

Bases: BaseServiceMessage

Message for processing stats. Sent by the records manager to the system controller to report the stats of the profile run.

Source code in aiperf/common/messages/progress_messages.py
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
class ProcessingStatsMessage(BaseServiceMessage):
    """Message for processing stats. Sent by the records manager to the system controller to report the stats of the profile run."""

    # NOTE(review): RecordsProcessingStatsMessage uses this same
    # MessageType.PROCESSING_STATS discriminator — Message.from_json's
    # one-to-one type lookup can only resolve one of them; confirm intended.
    message_type: MessageTypeT = MessageType.PROCESSING_STATS

    error_count: int = Field(default=0, description="The number of errors encountered")
    completed: int = Field(
        default=0, description="The number of requests processed by the records manager"
    )
    worker_completed: dict[str, int] = Field(
        default_factory=dict,
        description="Per-worker request completion counts, keyed by worker service_id",
    )
    worker_errors: dict[str, int] = Field(
        default_factory=dict,
        description="Per-worker error counts, keyed by worker service_id",
    )

ProfileProgressMessage

Bases: BaseServiceMessage

Message for profile progress. Sent by the timing manager to the system controller to report the progress of the profile run.

Source code in aiperf/common/messages/progress_messages.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class ProfileProgressMessage(BaseServiceMessage):
    """Message for profile progress. Sent by the timing manager to the system controller to report the progress of the profile run."""

    message_type: MessageTypeT = MessageType.PROFILE_PROGRESS

    profile_id: str | None = Field(
        default=None, description="The ID of the current profile"
    )
    start_ns: int = Field(
        ..., description="The start time of the profile run in nanoseconds"
    )
    # None while the run is still in progress.
    end_ns: int | None = Field(
        default=None, description="The end time of the profile run in nanoseconds"
    )
    total: int = Field(
        ..., description="The total number of inference requests to be made (if known)"
    )
    completed: int = Field(
        ..., description="The number of inference requests completed"
    )
    warmup: bool = Field(
        default=False,
        description="Whether this is the warmup phase of the profile run",
    )

ProfileResultsMessage

Bases: BaseServiceMessage

Message for profile results.

Source code in aiperf/common/messages/progress_messages.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
class ProfileResultsMessage(BaseServiceMessage):
    """Message for profile results.

    Final results payload for a profile run: metric records, request counts,
    the run's time window, and any aggregated errors.
    """

    message_type: MessageTypeT = MessageType.PROFILE_RESULTS

    # SerializeAsAny keeps subclass fields when serializing the records.
    records: SerializeAsAny[list[MetricResult]] = Field(
        ..., description="The records of the profile results"
    )
    total: int = Field(
        ...,
        description="The total number of inference requests expected to be made (if known)",
    )
    completed: int = Field(
        ..., description="The number of inference requests completed"
    )
    start_ns: int = Field(
        ..., description="The start time of the profile run in nanoseconds"
    )
    # Required here (unlike ProfileProgressMessage) — results imply the run ended.
    end_ns: int = Field(
        ..., description="The end time of the profile run in nanoseconds"
    )
    was_cancelled: bool = Field(
        default=False,
        description="Whether the profile run was cancelled early",
    )
    errors_by_type: list[ErrorDetailsCount] = Field(
        default_factory=list,
        description="A list of the unique error details and their counts",
    )

RecordsProcessingStatsMessage

Bases: BaseServiceMessage

Message for processing stats. Sent by the RecordsManager to report the stats of the profile run. This contains the stats for a single credit phase only.

Source code in aiperf/common/messages/progress_messages.py
73
74
75
76
77
78
79
80
81
82
83
84
85
86
class RecordsProcessingStatsMessage(BaseServiceMessage):
    """Message for processing stats. Sent by the RecordsManager to report the stats of the profile run.
    This contains the stats for a single credit phase only."""

    # NOTE(review): ProcessingStatsMessage uses this same
    # MessageType.PROCESSING_STATS discriminator — Message.from_json's
    # one-to-one type lookup can only resolve one of them; confirm intended.
    message_type: MessageTypeT = MessageType.PROCESSING_STATS

    processing_stats: PhaseProcessingStats = Field(
        ..., description="The stats for the credit phase"
    )
    worker_stats: dict[str, PhaseProcessingStats] = Field(
        default_factory=dict,
        description="The stats for each worker how many requests were processed and how many errors were "
        "encountered, keyed by worker service_id",
    )

SweepProgressMessage

Bases: BaseServiceMessage

Message for sweep progress.

Source code in aiperf/common/messages/progress_messages.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class SweepProgressMessage(BaseServiceMessage):
    """Message for sweep progress."""

    # TODO: add profile information

    message_type: MessageTypeT = MessageType.SWEEP_PROGRESS

    sweep_id: str = Field(..., description="The ID of the current sweep")
    sweep_start_ns: int = Field(
        ..., description="The start time of the sweep in nanoseconds"
    )
    # None while the sweep is still in progress.
    end_ns: int | None = Field(
        default=None, description="The end time of the profile run in nanoseconds"
    )

aiperf.common.messages.service_messages

BaseServiceErrorMessage

Bases: BaseServiceMessage

Base message containing error data.

Source code in aiperf/common/messages/service_messages.py
 97
 98
 99
100
101
102
class BaseServiceErrorMessage(BaseServiceMessage):
    """Base message containing error data."""

    message_type: MessageTypeT = MessageType.SERVICE_ERROR

    # Required — an error message without error details would be meaningless.
    error: ErrorDetails = Field(..., description="Error information")

BaseServiceMessage

Bases: Message

Base message that is sent from a service. Requires a service_id field to specify the service that sent the message.

Source code in aiperf/common/messages/service_messages.py
24
25
26
27
28
29
30
31
class BaseServiceMessage(Message):
    """Base message that is sent from a service. Requires a service_id field to specify
    the service that sent the message."""

    # Identifies the sender; required for every service-originated message.
    service_id: str = Field(
        ...,
        description="ID of the service sending the message",
    )

BaseStatusMessage

Bases: BaseServiceMessage

Base message containing status data. This message is sent by a service to the system controller to report its status.

Source code in aiperf/common/messages/service_messages.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class BaseStatusMessage(BaseServiceMessage):
    """Base message containing status data.
    This message is sent by a service to the system controller to report its status.
    """

    # override request_ns to be auto-filled if not provided
    request_ns: int | None = Field(
        # BUG FIX: was `default=time.time_ns()`, which evaluates ONCE at class
        # definition (import) time, so every status message silently shared the
        # same stale timestamp. `default_factory` defers the call to instance
        # creation, matching RequiresRequestNSMixin.
        default_factory=time.time_ns,
        description="Timestamp of the request",
    )
    state: ServiceState = Field(
        ...,
        description="Current state of the service",
    )
    service_type: ServiceType = Field(
        ...,
        description="Type of service",
    )

HeartbeatMessage

Bases: BaseStatusMessage

Message containing heartbeat data. This message is sent by a service to the system controller to indicate that it is still running.

Source code in aiperf/common/messages/service_messages.py
72
73
74
75
76
77
78
class HeartbeatMessage(BaseStatusMessage):
    """Message containing heartbeat data.
    This message is sent by a service to the system controller to indicate that it is
    still running.
    """

    # Inherits state/service_type from BaseStatusMessage; only the
    # discriminator differs.
    message_type: MessageTypeT = MessageType.HEARTBEAT

NotificationMessage

Bases: BaseServiceMessage

Message containing a notification from a service. This is used to notify other services of events.

Source code in aiperf/common/messages/service_messages.py
81
82
83
84
85
86
87
88
89
90
91
92
93
94
class NotificationMessage(BaseServiceMessage):
    """Message containing a notification from a service. This is used to notify other services of events."""

    message_type: MessageTypeT = MessageType.NOTIFICATION

    notification_type: NotificationType = Field(
        ...,
        description="The type of notification",
    )

    # SerializeAsAny keeps subclass fields when serializing the payload.
    data: SerializeAsAny[BaseModel | None] = Field(
        default=None,
        description="Data to send with the notification",
    )

RegistrationMessage

Bases: BaseStatusMessage

Message containing registration data. This message is sent by a service to the system controller to register itself.

Source code in aiperf/common/messages/service_messages.py
62
63
64
65
66
67
68
69
class RegistrationMessage(BaseStatusMessage):
    """Message containing registration data.
    This message is sent by a service to the system controller to register itself.
    """

    message_type: MessageTypeT = MessageType.REGISTRATION

    # Overrides the inherited required `state` with a READY default, since a
    # service registers once it is ready.
    state: ServiceState = ServiceState.READY

StatusMessage

Bases: BaseStatusMessage

Message containing status data. This message is sent by a service to the system controller to report its status.

Source code in aiperf/common/messages/service_messages.py
54
55
56
57
58
59
class StatusMessage(BaseStatusMessage):
    """Message containing status data.
    This message is sent by a service to the system controller to report its status.
    """

    # Inherits state/service_type from BaseStatusMessage; only the
    # discriminator differs.
    message_type: MessageTypeT = MessageType.STATUS

aiperf.common.mixins.aiperf_lifecycle_mixin

AIPerfLifecycleMixin

Bases: HooksMixin, AsyncTaskManagerMixin, AIPerfLoggerMixin

Mixin to add task support to a class. It abstracts away the details of the :class:AIPerfTask and provides a simple interface for registering and running tasks. It hooks into the :meth:HooksMixin.on_start and :meth:HooksMixin.on_stop hooks to start and stop the tasks.

Source code in aiperf/common/mixins/aiperf_lifecycle_mixin.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
@supports_hooks(
    AIPerfTaskHook.AIPERF_TASK,
    AIPerfTaskHook.AIPERF_AUTO_TASK,
    AIPerfHook.ON_INIT,
    AIPerfHook.ON_START,
    AIPerfHook.ON_STOP,
    AIPerfHook.ON_CLEANUP,
)
class AIPerfLifecycleMixin(HooksMixin, AsyncTaskManagerMixin, AIPerfLoggerMixin):
    """Mixin to add task support to a class. It abstracts away the details of the
    :class:`AIPerfTask` and provides a simple interface for registering and running tasks.
    It hooks into the :meth:`HooksMixin.on_start` and :meth:`HooksMixin.on_stop` hooks to
    start and stop the tasks.

    Lifecycle ordering: ON_INIT -> ON_START -> (run until shutdown requested)
    -> ON_STOP -> ON_CLEANUP. Progress is exposed via the asyncio events below.
    """

    def __init__(self, **kwargs):
        # Events track lifecycle progress; set in _run_lifecycle, awaited by
        # the wait_for_* helpers.
        self.initialized_event: asyncio.Event = asyncio.Event()
        self.started_event: asyncio.Event = asyncio.Event()
        self.stop_requested: asyncio.Event = asyncio.Event()
        self.shutdown_event: asyncio.Event = asyncio.Event()
        # Background task created by run_async / run_and_wait_for_start;
        # also serves as the "already running" guard.
        self.lifecycle_task: asyncio.Task | None = None
        super().__init__(**kwargs)

    def is_initialized(self) -> bool:
        """Check if the lifecycle has been initialized."""
        return self.initialized_event.is_set()

    async def _run_lifecycle(self) -> None:
        """Run the internal lifecycle."""
        # Run all the initialization hooks and set the initialize_event
        # NOTE(review): ON_INIT/ON_CLEANUP use run_hooks while ON_START/ON_STOP
        # use run_hooks_async — presumably sequential vs concurrent execution;
        # confirm the asymmetry is intentional.
        await self.run_hooks(AIPerfHook.ON_INIT)
        self.initialized_event.set()

        # Run all the start hooks and set the start_event
        await self.run_hooks_async(AIPerfHook.ON_START)
        self.started_event.set()

        while not self.stop_requested.is_set() and not self.shutdown_event.is_set():
            try:
                # Wait forever until the stop_requested event is set
                await self.stop_requested.wait()
            except asyncio.CancelledError:
                # Cancellation of the lifecycle task still proceeds to the
                # stop/cleanup hooks below.
                break
            except Exception as e:
                self.logger.exception("Unhandled exception in lifecycle: %s", e)
                continue

        try:
            # Run all the stop hooks
            await self.run_hooks_async(AIPerfHook.ON_STOP)
        except Exception as e:
            self.logger.exception("Unhandled exception in lifecycle: %s", e)

        try:
            # Run all the cleanup hooks and set the shutdown_event
            await self.run_hooks(AIPerfHook.ON_CLEANUP)
        except Exception as e:
            self.logger.exception("Unhandled exception in lifecycle: %s", e)
        finally:
            # Always set, even if cleanup raised, so waiters are released.
            self.shutdown_event.set()

        self.trace("Lifecycle finished")

    async def run_async(self) -> None:
        """Start the lifecycle in the background. Will call the :meth:`HooksMixin.on_init` hooks,
        followed by the :meth:`HooksMixin.on_start` hooks. Will return immediately."""
        if self.lifecycle_task is not None:
            raise InvalidStateError("Lifecycle is already running")
        self.lifecycle_task = asyncio.create_task(self._run_lifecycle())

    async def run_and_wait_for_start(self) -> None:
        """Start the lifecycle in the background and wait until the lifecycle is initialized and started.
        Will call the :meth:`HooksMixin.on_init` hooks, followed by the :meth:`HooksMixin.on_start` hooks."""
        if self.lifecycle_task is not None:
            raise InvalidStateError("Lifecycle is already running")
        self.lifecycle_task = asyncio.create_task(self._run_lifecycle())

        await self.initialized_event.wait()
        await self.started_event.wait()

    async def wait_for_initialize(self) -> None:
        """Wait for the lifecycle to be initialized. Will wait until the :meth:`HooksMixin.on_init` hooks have been called.
        Will return immediately if the lifecycle is already initialized."""
        await self.initialized_event.wait()

    async def wait_for_start(self) -> None:
        """Wait for the lifecycle to be started. Will wait until the :meth:`HooksMixin.on_start` hooks have been called.
        Will return immediately if the lifecycle is already started."""
        await self.started_event.wait()

    async def wait_for_shutdown(self) -> None:
        """Wait for the lifecycle to be shutdown. Will wait until the :meth:`HooksMixin.on_stop` hooks have been called.
        Will return immediately if the lifecycle is already shutdown."""
        await self.shutdown_event.wait()

    async def shutdown(self) -> None:
        """Shutdown the lifecycle. Will call the :meth:`HooksMixin.on_stop` hooks,
        followed by the :meth:`HooksMixin.on_cleanup` hooks."""
        # Signals _run_lifecycle to exit its wait loop; does not block until
        # shutdown completes — use wait_for_shutdown() for that.
        self.stop_requested.set()

    @on_start
    async def _start_tasks(self):
        """Start all the registered tasks in the background."""

        # Start all the non-auto tasks
        for hook in self.get_hooks(AIPerfTaskHook.AIPERF_TASK):
            if inspect.iscoroutinefunction(hook):
                self.execute_async(hook())
            else:
                # Sync hooks run in a worker thread so they don't block the loop.
                self.execute_async(asyncio.to_thread(hook))

        # Start all the auto tasks
        for hook in self.get_hooks(AIPerfTaskHook.AIPERF_AUTO_TASK):
            interval = getattr(hook, AIPerfTaskHook.AIPERF_AUTO_TASK_INTERVAL, None)
            self.execute_async(self._task_wrapper(hook, interval))

    @on_stop
    async def _stop_tasks(self):
        """Stop all the background tasks. This will wait for all the tasks to complete."""
        await self.cancel_all_tasks()

    async def _task_wrapper(
        self, func: Callable, interval: float | None = None
    ) -> None:
        """Wrapper to run a task in a loop until the stop_requested event is set.

        With ``interval=None`` the task runs exactly once; otherwise it repeats
        every ``interval`` seconds until stop is requested or it is cancelled.
        """
        while not self.stop_requested.is_set():
            try:
                if inspect.iscoroutinefunction(func):
                    await func()
                else:
                    await asyncio.to_thread(func)
            except asyncio.CancelledError:
                break
            except Exception:
                # A failing iteration is logged but does not kill the loop.
                self.logger.exception("Unhandled exception in task: %s", func.__name__)

            if interval is None:
                break
            await asyncio.sleep(interval)

is_initialized()

Check if the lifecycle has been initialized.

Source code in aiperf/common/mixins/aiperf_lifecycle_mixin.py
43
44
45
def is_initialized(self) -> bool:
    """Check if the lifecycle has been initialized.

    Returns:
        True once ``initialized_event`` has been set (i.e. the on_init
        phase has completed).
    """
    return self.initialized_event.is_set()

run_and_wait_for_start() async

Start the lifecycle in the background and wait until the lifecycle is initialized and started. Will call the :meth:HooksMixin.on_init hooks, followed by the :meth:HooksMixin.on_start hooks.

Source code in aiperf/common/mixins/aiperf_lifecycle_mixin.py
90
91
92
93
94
95
96
97
98
async def run_and_wait_for_start(self) -> None:
    """Start the lifecycle in the background and wait until the lifecycle is initialized and started.
    Will call the :meth:`HooksMixin.on_init` hooks, followed by the :meth:`HooksMixin.on_start` hooks.

    Raises:
        InvalidStateError: If the lifecycle has already been started.
    """
    if self.lifecycle_task is not None:
        raise InvalidStateError("Lifecycle is already running")
    self.lifecycle_task = asyncio.create_task(self._run_lifecycle())

    # Block until both phases have been signalled by the background task.
    await self.initialized_event.wait()
    await self.started_event.wait()

run_async() async

Start the lifecycle in the background. Will call the :meth:HooksMixin.on_init hooks, followed by the :meth:HooksMixin.on_start hooks. Will return immediately.

Source code in aiperf/common/mixins/aiperf_lifecycle_mixin.py
83
84
85
86
87
88
async def run_async(self) -> None:
    """Start the lifecycle in the background. Will call the :meth:`HooksMixin.on_init` hooks,
    followed by the :meth:`HooksMixin.on_start` hooks. Will return immediately.

    Raises:
        InvalidStateError: If the lifecycle has already been started.
    """
    if self.lifecycle_task is not None:
        raise InvalidStateError("Lifecycle is already running")
    self.lifecycle_task = asyncio.create_task(self._run_lifecycle())

shutdown() async

Shutdown the lifecycle. Will call the :meth:HooksMixin.on_stop hooks, followed by the :meth:HooksMixin.on_cleanup hooks.

Source code in aiperf/common/mixins/aiperf_lifecycle_mixin.py
115
116
117
118
async def shutdown(self) -> None:
    """Shutdown the lifecycle. Will call the :meth:`HooksMixin.on_stop` hooks,
    followed by the :meth:`HooksMixin.on_cleanup` hooks."""
    # Only requests the stop; presumably the running lifecycle task reacts to
    # this event and performs the actual hook execution -- confirm in
    # _run_lifecycle.
    self.stop_requested.set()

wait_for_initialize() async

Wait for the lifecycle to be initialized. Will wait until the :meth:HooksMixin.on_init hooks have been called. Will return immediately if the lifecycle is already initialized.

Source code in aiperf/common/mixins/aiperf_lifecycle_mixin.py
100
101
102
103
async def wait_for_initialize(self) -> None:
    """Wait for the lifecycle to be initialized. Will wait until the :meth:`HooksMixin.on_init` hooks have been called.
    Will return immediately if the lifecycle is already initialized."""
    # asyncio.Event stays set once set, so repeat calls return immediately.
    await self.initialized_event.wait()

wait_for_shutdown() async

Wait for the lifecycle to be shutdown. Will wait until the :meth:HooksMixin.on_stop hooks have been called. Will return immediately if the lifecycle is already shutdown.

Source code in aiperf/common/mixins/aiperf_lifecycle_mixin.py
110
111
112
113
async def wait_for_shutdown(self) -> None:
    """Wait for the lifecycle to be shutdown. Will wait until the :meth:`HooksMixin.on_stop` hooks have been called.
    Will return immediately if the lifecycle is already shutdown."""
    # asyncio.Event stays set once set, so repeat calls return immediately.
    await self.shutdown_event.wait()

wait_for_start() async

Wait for the lifecycle to be started. Will wait until the :meth:HooksMixin.on_start hooks have been called. Will return immediately if the lifecycle is already started.

Source code in aiperf/common/mixins/aiperf_lifecycle_mixin.py
105
106
107
108
async def wait_for_start(self) -> None:
    """Wait for the lifecycle to be started. Will wait until the :meth:`HooksMixin.on_start` hooks have been called.
    Will return immediately if the lifecycle is already started."""
    # asyncio.Event stays set once set, so repeat calls return immediately.
    await self.started_event.wait()

aiperf.common.mixins.aiperf_logger_mixin

AIPerfLoggerMixin

Bases: BaseMixin

Mixin to provide lazy evaluated logging for f-strings.

This mixin provides a logger with lazy evaluation support for f-strings, and direct log functions for all standard and custom logging levels.

see :class:AIPerfLogger for more details.

Usage

    class MyClass(AIPerfLoggerMixin):
        def __init__(self):
            super().__init__()
            self.trace(lambda: f"Processing {item} of {count} ({item / count * 100}% complete)")
            self.info("Simple string message")
            self.debug(lambda i=i: f"Binding loop variable: {i}")
            self.warning("Warning message: %s", "legacy support")
            self.success("Benchmark completed successfully")
            self.notice("Warmup has completed")
            self.exception(f"Direct f-string usage: {e}")

Source code in aiperf/common/mixins/aiperf_logger_mixin.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
class AIPerfLoggerMixin(BaseMixin):
    """Mixin that attaches an :class:`AIPerfLogger` with lazy f-string support.

    Provides one helper per standard or custom level. Each helper checks
    whether its level is enabled before delegating, so callable messages
    (lambdas) are never evaluated for suppressed levels.

    see :class:`AIPerfLogger` for more details.

    Usage:
        class MyClass(AIPerfLoggerMixin):
            def __init__(self):
                super().__init__()
                self.trace(lambda: f"Processing {item} of {count} ({item / count * 100}% complete)")
                self.info("Simple string message")
                self.debug(lambda i=i: f"Binding loop variable: {i}")
                self.warning("Warning message: %s", "legacy support")
                self.success("Benchmark completed successfully")
                self.notice("Warmup has completed")
                self.exception(f"Direct f-string usage: {e}")
    """

    def __init__(self, logger_name: str | None = None, **kwargs) -> None:
        super().__init__(**kwargs)
        # Default the logger name to the concrete subclass name.
        self.logger = AIPerfLogger(logger_name or self.__class__.__name__)
        self._log = self.logger._log
        self.is_enabled_for = self.logger._logger.isEnabledFor

    def log(
        self, level: int, message: str | Callable[..., str], *args, **kwargs
    ) -> None:
        """Log a message at a specified level with lazy evaluation."""
        if not self.is_enabled_for(level):
            return
        self._log(level, message, *args, **kwargs)

    def trace(self, message: str | Callable[..., str], *args, **kwargs) -> None:
        """Log a trace message with lazy evaluation."""
        if not self.is_enabled_for(_TRACE):
            return
        self._log(_TRACE, message, *args, **kwargs)

    def debug(self, message: str | Callable[..., str], *args, **kwargs) -> None:
        """Log a debug message with lazy evaluation."""
        if not self.is_enabled_for(_DEBUG):
            return
        self._log(_DEBUG, message, *args, **kwargs)

    def info(self, message: str | Callable[..., str], *args, **kwargs) -> None:
        """Log an info message with lazy evaluation."""
        if not self.is_enabled_for(_INFO):
            return
        self._log(_INFO, message, *args, **kwargs)

    def notice(self, message: str | Callable[..., str], *args, **kwargs) -> None:
        """Log a notice message with lazy evaluation."""
        if not self.is_enabled_for(_NOTICE):
            return
        self._log(_NOTICE, message, *args, **kwargs)

    def warning(self, message: str | Callable[..., str], *args, **kwargs) -> None:
        """Log a warning message with lazy evaluation."""
        if not self.is_enabled_for(_WARNING):
            return
        self._log(_WARNING, message, *args, **kwargs)

    def success(self, message: str | Callable[..., str], *args, **kwargs) -> None:
        """Log a success message with lazy evaluation."""
        if not self.is_enabled_for(_SUCCESS):
            return
        self._log(_SUCCESS, message, *args, **kwargs)

    def error(self, message: str | Callable[..., str], *args, **kwargs) -> None:
        """Log an error message with lazy evaluation."""
        if not self.is_enabled_for(_ERROR):
            return
        self._log(_ERROR, message, *args, **kwargs)

    def exception(self, message: str | Callable[..., str], *args, **kwargs) -> None:
        """Log an exception message (with traceback) with lazy evaluation."""
        if not self.is_enabled_for(_ERROR):
            return
        self._log(_ERROR, message, *args, exc_info=True, **kwargs)

    def critical(self, message: str | Callable[..., str], *args, **kwargs) -> None:
        """Log a critical message with lazy evaluation."""
        if not self.is_enabled_for(_CRITICAL):
            return
        self._log(_CRITICAL, message, *args, **kwargs)

critical(message, *args, **kwargs)

Log a critical message with lazy evaluation.

Source code in aiperf/common/mixins/aiperf_logger_mixin.py
96
97
98
99
def critical(self, message: str | Callable[..., str], *args, **kwargs) -> None:
    """Log a critical message with lazy evaluation."""
    if self.is_enabled_for(_CRITICAL):
        self._log(_CRITICAL, message, *args, **kwargs)

debug(message, *args, **kwargs)

Log a debug message with lazy evaluation.

Source code in aiperf/common/mixins/aiperf_logger_mixin.py
61
62
63
64
def debug(self, message: str | Callable[..., str], *args, **kwargs) -> None:
    """Log a debug message with lazy evaluation."""
    if self.is_enabled_for(_DEBUG):
        self._log(_DEBUG, message, *args, **kwargs)

error(message, *args, **kwargs)

Log an error message with lazy evaluation.

Source code in aiperf/common/mixins/aiperf_logger_mixin.py
86
87
88
89
def error(self, message: str | Callable[..., str], *args, **kwargs) -> None:
    """Log an error message with lazy evaluation."""
    if self.is_enabled_for(_ERROR):
        self._log(_ERROR, message, *args, **kwargs)

exception(message, *args, **kwargs)

Log an exception message with lazy evaluation.

Source code in aiperf/common/mixins/aiperf_logger_mixin.py
91
92
93
94
def exception(self, message: str | Callable[..., str], *args, **kwargs) -> None:
    """Log an exception message with lazy evaluation."""
    if self.is_enabled_for(_ERROR):
        self._log(_ERROR, message, *args, exc_info=True, **kwargs)

info(message, *args, **kwargs)

Log an info message with lazy evaluation.

Source code in aiperf/common/mixins/aiperf_logger_mixin.py
66
67
68
69
def info(self, message: str | Callable[..., str], *args, **kwargs) -> None:
    """Log an info message with lazy evaluation."""
    if self.is_enabled_for(_INFO):
        self._log(_INFO, message, *args, **kwargs)

log(level, message, *args, **kwargs)

Log a message at a specified level with lazy evaluation.

Source code in aiperf/common/mixins/aiperf_logger_mixin.py
49
50
51
52
53
54
def log(
    self, level: int, message: str | Callable[..., str], *args, **kwargs
) -> None:
    """Log a message at a specified level with lazy evaluation."""
    if self.is_enabled_for(level):
        self._log(level, message, *args, **kwargs)

notice(message, *args, **kwargs)

Log a notice message with lazy evaluation.

Source code in aiperf/common/mixins/aiperf_logger_mixin.py
71
72
73
74
def notice(self, message: str | Callable[..., str], *args, **kwargs) -> None:
    """Log a notice message with lazy evaluation."""
    if self.is_enabled_for(_NOTICE):
        self._log(_NOTICE, message, *args, **kwargs)

success(message, *args, **kwargs)

Log a success message with lazy evaluation.

Source code in aiperf/common/mixins/aiperf_logger_mixin.py
81
82
83
84
def success(self, message: str | Callable[..., str], *args, **kwargs) -> None:
    """Log a success message with lazy evaluation."""
    if self.is_enabled_for(_SUCCESS):
        self._log(_SUCCESS, message, *args, **kwargs)

trace(message, *args, **kwargs)

Log a trace message with lazy evaluation.

Source code in aiperf/common/mixins/aiperf_logger_mixin.py
56
57
58
59
def trace(self, message: str | Callable[..., str], *args, **kwargs) -> None:
    """Log a trace message with lazy evaluation."""
    if self.is_enabled_for(_TRACE):
        self._log(_TRACE, message, *args, **kwargs)

warning(message, *args, **kwargs)

Log a warning message with lazy evaluation.

Source code in aiperf/common/mixins/aiperf_logger_mixin.py
76
77
78
79
def warning(self, message: str | Callable[..., str], *args, **kwargs) -> None:
    """Log a warning message with lazy evaluation."""
    if self.is_enabled_for(_WARNING):
        self._log(_WARNING, message, *args, **kwargs)

AIPerfLoggerProtocol

Bases: Protocol

Protocol to provide lazy evaluated logging for f-strings.

Source code in aiperf/common/mixins/aiperf_logger_mixin.py
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
@runtime_checkable
class AIPerfLoggerProtocol(Protocol):
    """Protocol to provide lazy evaluated logging for f-strings.

    Structural counterpart of the logger-mixin interface: any object exposing
    these methods satisfies the protocol, and ``@runtime_checkable`` allows
    ``isinstance`` checks against it (methods only, not signatures).
    """

    def __init__(self, logger_name: str | None = None, **kwargs) -> None: ...
    def log(
        self, level: int, message: str | Callable[..., str], *args, **kwargs
    ) -> None: ...
    def trace(self, message: str | Callable[..., str], *args, **kwargs) -> None: ...
    def debug(self, message: str | Callable[..., str], *args, **kwargs) -> None: ...
    def info(self, message: str | Callable[..., str], *args, **kwargs) -> None: ...
    def notice(self, message: str | Callable[..., str], *args, **kwargs) -> None: ...
    def warning(self, message: str | Callable[..., str], *args, **kwargs) -> None: ...
    def success(self, message: str | Callable[..., str], *args, **kwargs) -> None: ...
    def error(self, message: str | Callable[..., str], *args, **kwargs) -> None: ...
    def exception(self, message: str | Callable[..., str], *args, **kwargs) -> None: ...
    def critical(self, message: str | Callable[..., str], *args, **kwargs) -> None: ...
    def is_enabled_for(self, level: int) -> bool: ...

aiperf.common.mixins.aiperf_profile_mixin

AIPerfProfileMixin

Bases: HooksMixin

Mixin to add profile-related hook support to a class.

Source code in aiperf/common/mixins/aiperf_profile_mixin.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
@supports_hooks(
    AIPerfHook.ON_PROFILE_CONFIGURE,
    AIPerfHook.ON_PROFILE_START,
    AIPerfHook.ON_PROFILE_STOP,
)
class AIPerfProfileMixin(HooksMixin):
    """Mixin to add profile-related hook support to a class.

    Tracks the profile lifecycle with four events (configured, started,
    stop-requested, stopped) and provides ``wait_for_*`` helpers for each.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Use the instance's concrete class name; a bare ``__class__`` here
        # always names this mixin, hiding which subclass actually logged.
        self.logger = logging.getLogger(self.__class__.__name__)
        self.profile_started_event: asyncio.Event = asyncio.Event()
        self.profile_stopped_event: asyncio.Event = asyncio.Event()
        self.request_profile_stop_event: asyncio.Event = asyncio.Event()
        self.profile_configured_event: asyncio.Event = asyncio.Event()

    async def configure_profile(self, message: Message):
        """Configure the profile by running the ON_PROFILE_CONFIGURE hooks serially."""
        await self.run_hooks(AIPerfHook.ON_PROFILE_CONFIGURE, message)
        self.profile_configured_event.set()

    async def run_profile(self):
        """Run the profile until a stop is requested, then run the stop hooks.

        Runs the ON_PROFILE_START hooks concurrently, waits for
        :meth:`stop_profile` (or task cancellation), then runs the
        ON_PROFILE_STOP hooks and signals ``profile_stopped_event``.
        """
        # Run all the start hooks and set the start_event
        await self.run_hooks_async(AIPerfHook.ON_PROFILE_START)
        self.profile_started_event.set()

        while not self.request_profile_stop_event.is_set():
            try:
                # Wait forever until the stop_requested event is set
                await self.request_profile_stop_event.wait()
            except asyncio.CancelledError:
                break
            except Exception as e:
                self.logger.exception(
                    "Unhandled exception while profile is running: %s", e
                )
                continue

        try:
            # Run all the stop hooks
            await self.run_hooks_async(AIPerfHook.ON_PROFILE_STOP)
        except Exception as e:
            self.logger.exception(
                "Unhandled exception while stopping the profile: %s", e
            )
        finally:
            # BUGFIX: this event was previously never set anywhere, so
            # wait_for_profile_stopped() would block forever.
            self.profile_stopped_event.set()

    async def stop_profile(self):
        """Request the profile to stop."""
        self.request_profile_stop_event.set()

    async def wait_for_profile_configured(self):
        """Wait for the profile to be configured."""
        await self.profile_configured_event.wait()

    async def wait_for_profile_started(self):
        """Wait for the profile to start."""
        await self.profile_started_event.wait()

    async def wait_for_profile_stopped(self):
        """Wait for the profile to stop."""
        await self.profile_stopped_event.wait()

configure_profile(message) async

Configure the profile.

Source code in aiperf/common/mixins/aiperf_profile_mixin.py
27
28
29
30
async def configure_profile(self, message: Message):
    """Configure the profile."""
    await self.run_hooks(AIPerfHook.ON_PROFILE_CONFIGURE, message)
    self.profile_configured_event.set()

run_profile() async

Run the profile.

Source code in aiperf/common/mixins/aiperf_profile_mixin.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
async def run_profile(self):
    """Run the profile."""
    # Run all the start hooks and set the start_event
    await self.run_hooks_async(AIPerfHook.ON_PROFILE_START)
    self.profile_started_event.set()

    while not self.request_profile_stop_event.is_set():
        try:
            # Wait forever until the stop_requested event is set
            await self.request_profile_stop_event.wait()
        except asyncio.CancelledError:
            break
        except Exception as e:
            self.logger.exception(
                "Unhandled exception in while profile is running: %s", e
            )
            continue

    try:
        # Run all the stop hooks
        await self.run_hooks_async(AIPerfHook.ON_PROFILE_STOP)
    except Exception as e:
        self.logger.exception(
            "Unhandled exception in while profile is running: %s", e
        )

stop_profile() async

Request the profile to stop.

Source code in aiperf/common/mixins/aiperf_profile_mixin.py
58
59
60
async def stop_profile(self):
    """Request the profile to stop."""
    self.request_profile_stop_event.set()

wait_for_profile_configured() async

Wait for the profile to be configured.

Source code in aiperf/common/mixins/aiperf_profile_mixin.py
62
63
64
async def wait_for_profile_configured(self):
    """Wait for the profile to be configured."""
    await self.profile_configured_event.wait()

wait_for_profile_started() async

Wait for the profile to start.

Source code in aiperf/common/mixins/aiperf_profile_mixin.py
66
67
68
async def wait_for_profile_started(self):
    """Wait for the profile to start."""
    await self.profile_started_event.wait()

wait_for_profile_stopped() async

Wait for the profile to stop.

Source code in aiperf/common/mixins/aiperf_profile_mixin.py
70
71
72
async def wait_for_profile_stopped(self):
    """Wait for the profile to stop."""
    await self.profile_stopped_event.wait()

aiperf.common.mixins.aiperf_task_mixin

AIPerfTaskMixin

Bases: HooksMixin, AsyncTaskManagerMixin

Mixin to add aiperf_task support to a class.

It hooks into the :meth:HooksMixin.on_init and :meth:HooksMixin.on_stop hooks to start and stop the tasks.

Source code in aiperf/common/mixins/aiperf_task_mixin.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
@supports_hooks(
    AIPerfTaskHook.AIPERF_TASK,
    AIPerfHook.ON_INIT,
    AIPerfHook.ON_START,
    AIPerfHook.ON_STOP,
)
class AIPerfTaskMixin(HooksMixin, AsyncTaskManagerMixin):
    """Mixin to add aiperf_task support to a class.

    It hooks into the :meth:`HooksMixin.on_init` and :meth:`HooksMixin.on_stop` hooks to
    start and stop the tasks.
    """

    # TODO: This is somewhat deprecated in favor of the lifecycle mixin.

    # TODO: Once we create a Mixin for `self.stop_event`, we can avoid
    # having the user to call `while not self.stop_event.is_set()`

    def __init__(self, **kwargs):
        # No state of its own; just continues the cooperative __init__ chain.
        super().__init__(**kwargs)

    async def initialize(self) -> None:
        """Initialize the task."""
        # Runs the ON_INIT hooks serially; _start_tasks below is one of them.
        await self.run_hooks(AIPerfHook.ON_INIT)

    async def start(self) -> None:
        """Start the task."""
        await self.run_hooks(AIPerfHook.ON_START)

    async def stop(self) -> None:
        """Stop the task."""
        # Runs the ON_STOP hooks serially; _stop_tasks below cancels the tasks.
        await self.run_hooks(AIPerfHook.ON_STOP)

    # TODO: Should this be on_start?
    @on_init
    async def _start_tasks(self):
        """Start all the registered tasks in the background."""
        # Coroutine hooks run on the event loop; sync hooks go to a thread.
        for hook in self.get_hooks(AIPerfTaskHook.AIPERF_TASK):
            if inspect.iscoroutinefunction(hook):
                self.execute_async(hook())
            else:
                self.execute_async(asyncio.to_thread(hook))

    @on_stop
    async def _stop_tasks(self):
        """Stop all the background tasks. This will wait for all the tasks to complete."""
        await self.cancel_all_tasks()

initialize() async

Initialize the task.

Source code in aiperf/common/mixins/aiperf_task_mixin.py
38
39
40
async def initialize(self) -> None:
    """Initialize the task."""
    await self.run_hooks(AIPerfHook.ON_INIT)

start() async

Start the task.

Source code in aiperf/common/mixins/aiperf_task_mixin.py
42
43
44
async def start(self) -> None:
    """Start the task."""
    await self.run_hooks(AIPerfHook.ON_START)

stop() async

Stop the task.

Source code in aiperf/common/mixins/aiperf_task_mixin.py
46
47
48
async def stop(self) -> None:
    """Stop the task."""
    await self.run_hooks(AIPerfHook.ON_STOP)

aiperf.common.mixins.async_task_manager_mixin

AsyncTaskManagerMixin

Bases: BaseMixin

Mixin to manage a set of async tasks.

Source code in aiperf/common/mixins/async_task_manager_mixin.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
class AsyncTaskManagerMixin(BaseMixin):
    """Mixin that tracks a set of fire-and-forget asyncio tasks."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.tasks: set[asyncio.Task] = set()

    def execute_async(self, coro: Coroutine) -> asyncio.Task:
        """Schedule ``coro`` as a tracked task and return it immediately.

        Finished tasks remove themselves from the tracking set via a
        done-callback, so the set only holds live work.
        """
        scheduled = asyncio.create_task(coro)
        self.tasks.add(scheduled)
        scheduled.add_done_callback(self.tasks.discard)
        return scheduled

    async def wait_for_tasks(self) -> None:
        """Block until every currently tracked task has completed."""
        snapshot = list(self.tasks)
        await asyncio.gather(*snapshot)

    async def cancel_all_tasks(
        self, timeout: float = TASK_CANCEL_TIMEOUT_SHORT
    ) -> None:
        """Cancel every tracked task, waiting at most ``timeout`` seconds for them.

        Args:
            timeout: The timeout to wait for the tasks to complete.
        """
        if not self.tasks:
            return

        for pending in list(self.tasks):
            pending.cancel()

        # Swallow both the timeout and any cancellation bubbling back up.
        with contextlib.suppress(asyncio.TimeoutError, asyncio.CancelledError):
            await asyncio.wait_for(
                asyncio.gather(*self.tasks, return_exceptions=True), timeout=timeout
            )

        # Drop references so finished/cancelled tasks can be garbage collected.
        self.tasks.clear()

cancel_all_tasks(timeout=TASK_CANCEL_TIMEOUT_SHORT) async

Cancel all tasks in the set and wait for up to timeout seconds for them to complete.

Parameters:

Name Type Description Default
timeout float

The timeout to wait for the tasks to complete.

TASK_CANCEL_TIMEOUT_SHORT
Source code in aiperf/common/mixins/async_task_manager_mixin.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
async def cancel_all_tasks(
    self, timeout: float = TASK_CANCEL_TIMEOUT_SHORT
) -> None:
    """Cancel all tasks in the set and wait for up to timeout seconds for them to complete.

    Args:
        timeout: The timeout to wait for the tasks to complete.
    """
    if not self.tasks:
        return

    for task in list(self.tasks):
        task.cancel()

    with contextlib.suppress(asyncio.TimeoutError, asyncio.CancelledError):
        await asyncio.wait_for(
            asyncio.gather(*self.tasks, return_exceptions=True), timeout=timeout
        )

    # Clear the tasks set after cancellation to avoid memory leaks
    self.tasks.clear()

execute_async(coro)

Create a task from a coroutine and add it to the set of tasks, and return immediately. The task will be automatically cleaned up when it completes.

Source code in aiperf/common/mixins/async_task_manager_mixin.py
19
20
21
22
23
24
25
26
def execute_async(self, coro: Coroutine) -> asyncio.Task:
    """Create a task from a coroutine and add it to the set of tasks, and return immediately.
    The task will be automatically cleaned up when it completes.
    """
    task = asyncio.create_task(coro)
    self.tasks.add(task)
    task.add_done_callback(self.tasks.discard)
    return task

wait_for_tasks() async

Wait for all current tasks to complete.

Source code in aiperf/common/mixins/async_task_manager_mixin.py
28
29
30
async def wait_for_tasks(self) -> None:
    """Wait for all current tasks to complete."""
    await asyncio.gather(*list(self.tasks))

AsyncTaskManagerProtocol

Bases: Protocol

Protocol to manage a set of async tasks.

Source code in aiperf/common/mixins/async_task_manager_mixin.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
@runtime_checkable
class AsyncTaskManagerProtocol(Protocol):
    """Protocol to manage a set of async tasks.

    NOTE(review): this protocol declares :meth:`stop`, which the visible
    AsyncTaskManagerMixin does not implement -- confirm implementers provide it.
    """

    def execute_async(self, coro: Coroutine) -> asyncio.Task:
        """Create a task from a coroutine and add it to the set of tasks, and return immediately.
        The task will be automatically cleaned up when it completes.
        """
        ...

    async def stop(self) -> None:
        """Stop all tasks in the set and wait for them to complete."""

    async def cancel_all_tasks(
        self, timeout: float = TASK_CANCEL_TIMEOUT_SHORT
    ) -> None:
        """Cancel all tasks in the set and wait for up to timeout seconds for them to complete.

        Args:
            timeout: The timeout to wait for the tasks to complete.
        """

cancel_all_tasks(timeout=TASK_CANCEL_TIMEOUT_SHORT) async

Cancel all tasks in the set and wait for up to timeout seconds for them to complete.

Parameters:

Name Type Description Default
timeout float

The timeout to wait for the tasks to complete.

TASK_CANCEL_TIMEOUT_SHORT
Source code in aiperf/common/mixins/async_task_manager_mixin.py
68
69
70
71
72
73
74
75
async def cancel_all_tasks(
    self, timeout: float = TASK_CANCEL_TIMEOUT_SHORT
) -> None:
    """Cancel all tasks in the set and wait for up to timeout seconds for them to complete.

    Args:
        timeout: The timeout to wait for the tasks to complete.
    """

execute_async(coro)

Create a task from a coroutine and add it to the set of tasks, and return immediately. The task will be automatically cleaned up when it completes.

Source code in aiperf/common/mixins/async_task_manager_mixin.py
59
60
61
62
63
def execute_async(self, coro: Coroutine) -> asyncio.Task:
    """Create a task from a coroutine and add it to the set of tasks, and return immediately.
    The task will be automatically cleaned up when it completes.
    """
    ...

stop() async

Stop all tasks in the set and wait for them to complete.

Source code in aiperf/common/mixins/async_task_manager_mixin.py
65
66
async def stop(self) -> None:
    """Stop all tasks in the set and wait for them to complete."""

aiperf.common.mixins.base_mixin

BaseMixin

Base mixin class.

This Mixin creates a contract that Mixins should always pass **kwargs to `super().__init__`, regardless of whether they extend another mixin or not.

This will ensure that the BaseMixin is the last mixin to have its `__init__` method called, which means that all other mixins will have a proper chain of `__init__` methods with the correct arguments and no accidental broken inheritance.

Source code in aiperf/common/mixins/base_mixin.py
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
class BaseMixin:
    """Base mixin class.

    This Mixin creates a contract that Mixins should always pass **kwargs to
    super().__init__, regardless of whether they extend another mixin or not.

    This will ensure that the BaseMixin is the last mixin to have its __init__
    method called, which means that all other mixins will have a proper
    chain of __init__ methods with the correct arguments and no accidental
    broken inheritance.
    """

    def __init__(self, **kwargs):
        # object.__init__ does not take any arguments, so any **kwargs still
        # unclaimed at the end of the MRO chain are deliberately dropped here.
        super().__init__()

aiperf.common.mixins.hooks_mixin

HooksMixin

Bases: BaseMixin

Mixin to add hook support to a class. It abstracts away the details of the :class:HookSystem and provides a simple interface for registering and running hooks.

Source code in aiperf/common/mixins/hooks_mixin.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
class HooksMixin(BaseMixin):
    """
    Mixin to add hook support to a class. It abstracts away the details of the
    :class:`HookSystem` and provides a simple interface for registering and running hooks.
    """

    # Class attributes that are set by the :func:`supports_hooks` decorator
    supported_hooks: ClassVar[set[HookType]] = set()

    def __init__(self, **kwargs):
        """
        Initialize the hook system and register all functions that are decorated with a hook decorator.
        """
        # Initialize the hook system
        self._hook_system = HookSystem(self.supported_hooks)

        # Register all functions that are decorated with a hook decorator
        # Iterate through MRO in reverse order to ensure base class hooks are registered first
        for cls in reversed(self.__class__.__mro__):
            # Skip object and other non-hook classes
            if not issubclass(cls, HooksMixin):
                continue

            # Get methods defined directly in this class (not inherited).
            # Note: if a subclass overrides a decorated method, both the base
            # and the overriding function get registered (one per __dict__).
            for _, attr in cls.__dict__.items():
                if callable(attr) and hasattr(attr, AIPERF_HOOK_TYPE):
                    # Get the hook type from the function
                    hook_type = getattr(attr, AIPERF_HOOK_TYPE)
                    # Bind the method to the instance
                    bound_method = attr.__get__(self, cls)
                    # Register the function with the hook type
                    self.register_hook(hook_type, bound_method)

        # NOTE(review): **kwargs is not forwarded here -- confirm this is
        # intentionally the end of the cooperative __init__ chain.
        super().__init__()

    def register_hook(self, hook_type: HookType, func: Callable):
        """Register a hook function for a given hook type.

        Args:
            hook_type: The hook type to register the function for.
            func: The function to register.
        """
        self._hook_system.register_hook(hook_type, func)

    async def run_hooks(self, hook_type: HookType, *args, **kwargs):
        """Run all the hooks serially. See :meth:`HookSystem.run_hooks`."""
        await self._hook_system.run_hooks(hook_type, *args, **kwargs)

    async def run_hooks_async(self, hook_type: HookType, *args, **kwargs):
        """Run all the hooks concurrently. See :meth:`HookSystem.run_hooks_async`."""
        await self._hook_system.run_hooks_async(hook_type, *args, **kwargs)

    def get_hooks(self, hook_type: HookType) -> list[Callable]:
        """Get all the registered hooks for the given hook type. See :meth:`HookSystem.get_hooks`."""
        return self._hook_system.get_hooks(hook_type)

__init__(**kwargs)

Initialize the hook system and register all functions that are decorated with a hook decorator.

Source code in aiperf/common/mixins/hooks_mixin.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def __init__(self, **kwargs):
    """
    Initialize the hook system and register all functions that are decorated with a hook decorator.
    """
    # Initialize the hook system
    self._hook_system = HookSystem(self.supported_hooks)

    # Register all functions that are decorated with a hook decorator
    # Iterate through MRO in reverse order to ensure base class hooks are registered first
    for cls in reversed(self.__class__.__mro__):
        # Skip object and other non-hook classes
        if not issubclass(cls, HooksMixin):
            continue

        # Get methods defined directly in this class (not inherited)
        for _, attr in cls.__dict__.items():
            if callable(attr) and hasattr(attr, AIPERF_HOOK_TYPE):
                # Get the hook type from the function
                hook_type = getattr(attr, AIPERF_HOOK_TYPE)
                # Bind the method to the instance
                bound_method = attr.__get__(self, cls)
                # Register the function with the hook type
                self.register_hook(hook_type, bound_method)

    super().__init__()

get_hooks(hook_type)

Get all the registered hooks for the given hook type. See :meth:HookSystem.get_hooks.

Source code in aiperf/common/mixins/hooks_mixin.py
62
63
64
def get_hooks(self, hook_type: HookType) -> list[Callable]:
    """Get all the registered hooks for the given hook type. See :meth:`HookSystem.get_hooks`."""
    return self._hook_system.get_hooks(hook_type)

register_hook(hook_type, func)

Register a hook function for a given hook type.

Parameters:

Name Type Description Default
hook_type HookType

The hook type to register the function for.

required
func Callable

The function to register.

required
Source code in aiperf/common/mixins/hooks_mixin.py
45
46
47
48
49
50
51
52
def register_hook(self, hook_type: HookType, func: Callable):
    """Register a hook function for a given hook type.

    Args:
        hook_type: The hook type to register the function for.
        func: The function to register.
    """
    self._hook_system.register_hook(hook_type, func)

run_hooks(hook_type, *args, **kwargs) async

Run all the hooks serially. See :meth:HookSystem.run_hooks.

Source code in aiperf/common/mixins/hooks_mixin.py
54
55
56
async def run_hooks(self, hook_type: HookType, *args, **kwargs):
    """Run all the hooks serially. See :meth:`HookSystem.run_hooks`."""
    await self._hook_system.run_hooks(hook_type, *args, **kwargs)

run_hooks_async(hook_type, *args, **kwargs) async

Run all the hooks concurrently. See :meth:HookSystem.run_hooks_async.

Source code in aiperf/common/mixins/hooks_mixin.py
58
59
60
async def run_hooks_async(self, hook_type: HookType, *args, **kwargs):
    """Run all the hooks concurrently. See :meth:`HookSystem.run_hooks_async`."""
    await self._hook_system.run_hooks_async(hook_type, *args, **kwargs)

aiperf.common.mixins.process_health_mixin

ProcessHealthMixin

Bases: BaseMixin

Mixin to provide process health information.

Source code in aiperf/common/mixins/process_health_mixin.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class ProcessHealthMixin(BaseMixin):
    """Mixin to provide process health information."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Initialize process-specific CPU monitoring
        self.process: psutil.Process = psutil.Process()
        self.process.cpu_percent()  # throw away the first result (will be 0)
        self.create_time: float = self.process.create_time()

        self.process_health: ProcessHealth | None = None
        self.previous: ProcessHealth | None = None

    def get_process_health(self) -> ProcessHealth:
        """Get the process health information for the current process."""

        # Get process-specific CPU and memory usage
        raw_cpu_times = self.process.cpu_times()
        cpu_times = CPUTimes(
            user=raw_cpu_times[0],
            system=raw_cpu_times[1],
            iowait=raw_cpu_times[4] if len(raw_cpu_times) > 4 else 0.0,  # type: ignore
        )

        self.previous = self.process_health

        self.process_health = ProcessHealth(
            pid=self.process.pid,
            create_time=self.create_time,
            uptime=time.time() - self.create_time,
            cpu_usage=self.process.cpu_percent(),
            memory_usage=self.process.memory_info().rss / BYTES_PER_MIB,
            io_counters=self.process.io_counters(),
            cpu_times=cpu_times,
            num_ctx_switches=CtxSwitches(*self.process.num_ctx_switches()),
            num_threads=self.process.num_threads(),
        )
        return self.process_health

get_process_health()

Get the process health information for the current process.

Source code in aiperf/common/mixins/process_health_mixin.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def get_process_health(self) -> ProcessHealth:
    """Get the process health information for the current process."""

    # Get process-specific CPU and memory usage
    raw_cpu_times = self.process.cpu_times()
    cpu_times = CPUTimes(
        user=raw_cpu_times[0],
        system=raw_cpu_times[1],
        iowait=raw_cpu_times[4] if len(raw_cpu_times) > 4 else 0.0,  # type: ignore
    )

    self.previous = self.process_health

    self.process_health = ProcessHealth(
        pid=self.process.pid,
        create_time=self.create_time,
        uptime=time.time() - self.create_time,
        cpu_usage=self.process.cpu_percent(),
        memory_usage=self.process.memory_info().rss / BYTES_PER_MIB,
        io_counters=self.process.io_counters(),
        cpu_times=cpu_times,
        num_ctx_switches=CtxSwitches(*self.process.num_ctx_switches()),
        num_threads=self.process.num_threads(),
    )
    return self.process_health

aiperf.common.models.base_models

AIPerfBaseModel

Bases: BaseModel

Base model for all AIPerf Pydantic models. This class is configured to allow arbitrary types to be used as fields, so as to allow for more flexible model definitions by end users without breaking the existing code.

Source code in aiperf/common/models/base_models.py
10
11
12
13
14
15
16
class AIPerfBaseModel(BaseModel):
    """Base model for all AIPerf Pydantic models. This class is configured to allow
    arbitrary types to be used as fields as to allow for more flexible model definitions
    by end users without breaking the existing code.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

ExcludeIfNoneMixin

Bases: AIPerfBaseModel

Mixin to exclude fields from the serialized model if they are None.

The @exclude_if_none decorator can be used to specify which fields should be excluded from the serialized model if they are None.

Source code in aiperf/common/models/base_models.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class ExcludeIfNoneMixin(AIPerfBaseModel):
    """Mixin to exclude fields from the serialized model if they are None.

    The @exclude_if_none decorator can be used to specify which fields
    should be excluded from the serialized model if they are None.
    """

    _exclude_if_none_fields: ClassVar[set[str]] = set()
    """Set of field names that should be excluded from the serialized model if they
    are None. This is set by the @exclude_if_none decorator.
    """

    @model_serializer
    def _serialize_model(self) -> dict[str, Any]:
        """Serialize the model to a dictionary.

        This method overrides the default serializer to exclude fields that with a
        value of None and were marked with the @exclude_if_none decorator.
        """
        return {
            k: v
            for k, v in self
            if not (k in self._exclude_if_none_fields and v is None)
        }

exclude_if_none(field_names)

Decorator to set the _exclude_if_none_fields class attribute to the set of field names that should be excluded if they are None.

Source code in aiperf/common/models/base_models.py
19
20
21
22
23
24
25
26
27
28
29
30
def exclude_if_none(field_names: list[str]):
    """Decorator to set the _exclude_if_none_fields class attribute to the set of
    field names that should be excluded if they are None.
    """

    def decorator(model: type[BaseModelT]) -> type[BaseModelT]:
        if not hasattr(model, "_exclude_if_none_fields"):
            model._exclude_if_none_fields = set()
        model._exclude_if_none_fields.update(field_names)
        return model

    return decorator

aiperf.common.models.credit_models

CreditPhaseConfig

Bases: AIPerfBaseModel

Model for phase credit config. This is used by the TimingManager to configure the credit phases.

Source code in aiperf/common/models/credit_models.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class CreditPhaseConfig(AIPerfBaseModel):
    """Model for phase credit config. This is used by the TimingManager to configure the credit phases."""

    type: CreditPhase = Field(..., description="The type of credit phase")
    total_expected_requests: int | None = Field(
        default=None,
        ge=1,
        description="The total number of expected credits. If None, the phase is not request count based.",
    )
    expected_duration_sec: float | None = Field(
        default=None,
        ge=1,
        description="The expected duration of the credit phase in seconds. If None, the phase is not time based.",
    )

    @property
    def is_time_based(self) -> bool:
        return self.expected_duration_sec is not None

    @property
    def is_request_count_based(self) -> bool:
        return self.total_expected_requests is not None

    @property
    def is_valid(self) -> bool:
        """A phase config is valid if it is exactly one of the following:
        - is_time_based (expected_duration_sec is set and > 0)
        - is_request_count_based (total_expected_requests is set and > 0)
        """
        is_time_based = self.is_time_based
        is_request_count_based = self.is_request_count_based
        return (is_time_based and not is_request_count_based) or (
            not is_time_based and is_request_count_based
        )

is_valid property

A phase config is valid if it is exactly one of the following: - is_time_based (expected_duration_sec is set and > 0) - is_request_count_based (total_expected_requests is set and > 0)

CreditPhaseStats

Bases: CreditPhaseConfig

Model for phase credit stats. Extends the CreditPhaseConfig fields to track the progress of the credit phases: how many credits were dropped and how many were returned, as well as the progress percentage of the phase.

Source code in aiperf/common/models/credit_models.py
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
class CreditPhaseStats(CreditPhaseConfig):
    """Model for phase credit stats. Extends the CreditPhaseConfig fields to track the progress of the credit phases.
    How many credits were dropped and how many were returned, as well as the progress percentage of the phase."""

    start_ns: int | None = Field(
        default=None,
        description="The start time of the credit phase in nanoseconds.",
    )
    sent_end_ns: int | None = Field(
        default=None,
        description="The time of the last sent credit in nanoseconds. If None, the phase has not sent all credits.",
    )
    end_ns: int | None = Field(
        default=None,
        ge=1,
        description="The time in which the last credit was returned from the workers in nanoseconds. If None, the phase has not completed.",
    )
    sent: int = Field(default=0, description="The number of sent credits")
    completed: int = Field(
        default=0,
        description="The number of completed credits (returned from the workers)",
    )

    @property
    def is_sending_complete(self) -> bool:
        return self.sent_end_ns is not None

    @property
    def is_complete(self) -> bool:
        return self.is_sending_complete and self.end_ns is not None

    @property
    def is_started(self) -> bool:
        return self.start_ns is not None

    @property
    def in_flight(self) -> int:
        """Calculate the number of in-flight credits (sent but not completed)."""
        return self.sent - self.completed

    @property
    def should_send(self) -> bool:
        """Whether the phase should send more credits."""
        if self.is_time_based:
            return (
                time.time_ns() - (self.start_ns or 0)
                <= (self.expected_duration_sec * NANOS_PER_SECOND)  # type: ignore
            )
        elif self.is_request_count_based:
            return self.sent < self.total_expected_requests  # type: ignore
        raise InvalidStateError("Phase is not time or request count based")

    @property
    def progress_percent(self) -> float | None:
        if self.start_ns is None:
            return None

        if self.is_complete:
            return 100

        if self.is_time_based:
            # Time based, so progress is the percentage of time elapsed compared to the duration

            return (
                (time.time_ns() - self.start_ns)
                / (self.expected_duration_sec * NANOS_PER_SECOND)  # type: ignore
            ) * 100

        elif self.total_expected_requests is not None:
            # Credit count based, so progress is the percentage of credits returned
            return (self.completed / self.total_expected_requests) * 100

        # We don't know the progress
        return None

    @classmethod
    def from_phase_config(cls, phase_config: CreditPhaseConfig) -> "CreditPhaseStats":
        """Create a CreditPhaseStats from a CreditPhaseConfig. This is used to initialize the stats for a phase."""
        return cls(
            type=phase_config.type,
            total_expected_requests=phase_config.total_expected_requests,
            expected_duration_sec=phase_config.expected_duration_sec,
        )

in_flight property

Calculate the number of in-flight credits (sent but not completed).

should_send property

Whether the phase should send more credits.

from_phase_config(phase_config) classmethod

Create a CreditPhaseStats from a CreditPhaseConfig. This is used to initialize the stats for a phase.

Source code in aiperf/common/models/credit_models.py
125
126
127
128
129
130
131
132
@classmethod
def from_phase_config(cls, phase_config: CreditPhaseConfig) -> "CreditPhaseStats":
    """Create a CreditPhaseStats from a CreditPhaseConfig. This is used to initialize the stats for a phase."""
    return cls(
        type=phase_config.type,
        total_expected_requests=phase_config.total_expected_requests,
        expected_duration_sec=phase_config.expected_duration_sec,
    )

PhaseProcessingStats

Bases: AIPerfBaseModel

Model for phase processing stats: how many requests were processed and how many errors were encountered.

Source code in aiperf/common/models/credit_models.py
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
class PhaseProcessingStats(AIPerfBaseModel):
    """Model for phase processing stats. How many requests were processed and
    how many errors were encountered."""

    processed: int = Field(
        default=0, description="The number of records processed successfully"
    )
    errors: int = Field(
        default=0, description="The number of record errors encountered"
    )
    total_expected_requests: int | None = Field(
        default=None,
        description="The total number of expected requests to process. If None, the phase is not request count based.",
    )

    @property
    def total_records(self) -> int:
        """The total number of records processed successfully or in error."""
        return self.processed + self.errors

total_records property

The total number of records processed successfully or in error.

aiperf.common.models.dataset_models

Conversation

Bases: AIPerfBaseModel

A dataset representation of a full conversation.

A conversation is a sequence of turns between a user and an endpoint, and it contains the session ID and all the turns that constitute the conversation.

Source code in aiperf/common/models/dataset_models.py
63
64
65
66
67
68
69
70
71
72
73
class Conversation(AIPerfBaseModel):
    """A dataset representation of a full conversation.

    A conversation is a sequence of turns between a user and an endpoint,
    and it contains the session ID and all the turns that consists the conversation.
    """

    turns: list[Turn] = Field(
        default=[], description="List of turns in the conversation."
    )
    session_id: str = Field(default="", description="Session ID of the conversation.")

Turn

Bases: AIPerfBaseModel

A dataset representation of a single turn within a conversation.

A turn is a single interaction between a user and an AI assistant, and it contains the timestamp, delay, and raw data that the user sends in each turn.

Source code in aiperf/common/models/dataset_models.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
@exclude_if_none(["role"])
class Turn(AIPerfBaseModel):
    """A dataset representation of a single turn within a conversation.

    A turn is a single interaction between a user and an AI assistant,
    and it contains timestamp, delay, and raw data that user sends in each turn.
    """

    timestamp: int | None = Field(
        default=None, description="Timestamp of the turn in milliseconds."
    )
    delay: int | None = Field(
        default=None,
        description="Amount of milliseconds to wait before sending the turn.",
    )
    role: str | None = Field(default=None, description="Role of the turn.")
    texts: list[Text] = Field(
        default=[], description="Collection of text data in each turn."
    )
    images: list[Image] = Field(
        default=[], description="Collection of image data in each turn."
    )
    audios: list[Audio] = Field(
        default=[], description="Collection of audio data in each turn."
    )

aiperf.common.models.error_models

ErrorDetails

Bases: AIPerfBaseModel

Encapsulates details about an error.

Source code in aiperf/common/models/error_models.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class ErrorDetails(AIPerfBaseModel):
    """Encapsulates details about an error."""

    code: int | None = Field(
        default=None,
        description="The error code.",
    )
    type: str | None = Field(
        default=None,
        description="The error type.",
    )
    message: str = Field(
        ...,
        description="The error message.",
    )

    def __eq__(self, other: Any) -> bool:
        """Check if the error details are equal by comparing the code, type, and message."""
        if not isinstance(other, ErrorDetails):
            return False
        return (
            self.code == other.code
            and self.type == other.type
            and self.message == other.message
        )

    def __hash__(self) -> int:
        """Hash the error details by hashing the code, type, and message."""
        return hash((self.code, self.type, self.message))

    @classmethod
    def from_exception(cls, e: Exception) -> "ErrorDetails":
        """Create an error details object from an exception."""
        return cls(
            type=e.__class__.__name__,
            message=str(e),
        )

__eq__(other)

Check if the error details are equal by comparing the code, type, and message.

Source code in aiperf/common/models/error_models.py
26
27
28
29
30
31
32
33
34
def __eq__(self, other: Any) -> bool:
    """Check if the error details are equal by comparing the code, type, and message."""
    if not isinstance(other, ErrorDetails):
        return False
    return (
        self.code == other.code
        and self.type == other.type
        and self.message == other.message
    )

__hash__()

Hash the error details by hashing the code, type, and message.

Source code in aiperf/common/models/error_models.py
36
37
38
def __hash__(self) -> int:
    """Hash the error details by hashing the code, type, and message."""
    return hash((self.code, self.type, self.message))

from_exception(e) classmethod

Create an error details object from an exception.

Source code in aiperf/common/models/error_models.py
40
41
42
43
44
45
46
@classmethod
def from_exception(cls, e: Exception) -> "ErrorDetails":
    """Create an error details object from an exception."""
    return cls(
        type=e.__class__.__name__,
        message=str(e),
    )

ErrorDetailsCount

Bases: AIPerfBaseModel

Count of error details.

Source code in aiperf/common/models/error_models.py
49
50
51
52
53
54
55
56
class ErrorDetailsCount(AIPerfBaseModel):
    """Count of error details."""

    error_details: ErrorDetails
    count: int = Field(
        ...,
        description="The count of the error details.",
    )

aiperf.common.models.health_models

ProcessHealth

Bases: AIPerfBaseModel

Model for process health data.

Source code in aiperf/common/models/health_models.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
class ProcessHealth(AIPerfBaseModel):
    """Model for process health data."""

    pid: int | None = Field(
        default=None,
        description="The PID of the process",
    )
    create_time: float = Field(
        ..., description="The creation time of the process in seconds"
    )
    uptime: float = Field(..., description="The uptime of the process in seconds")
    cpu_usage: float = Field(
        ..., description="The current CPU usage of the process in %"
    )
    memory_usage: float = Field(
        ..., description="The current memory usage of the process in MiB (rss)"
    )
    io_counters: IOCounters | tuple | None = Field(
        default=None,
        description="The current I/O counters of the process (read_count, write_count, read_bytes, write_bytes, read_chars, write_chars)",
    )
    cpu_times: CPUTimes | tuple | None = Field(
        default=None,
        description="The current CPU times of the process (user, system, iowait)",
    )
    num_ctx_switches: CtxSwitches | tuple | None = Field(
        default=None,
        description="The current number of context switches (voluntary, involuntary)",
    )
    num_threads: int | None = Field(
        default=None,
        description="The current number of threads",
    )

aiperf.common.models.record_models

InferenceServerResponse

Bases: AIPerfBaseModel

Response from an inference client.

Source code in aiperf/common/models/record_models.py
52
53
54
55
56
57
58
class InferenceServerResponse(AIPerfBaseModel):
    """Response from a inference client."""

    perf_ns: int = Field(
        ...,
        description="The timestamp of the response in nanoseconds (perf_counter_ns).",
    )

MetricResult

Bases: AIPerfBaseModel

The result values of a single metric.

Source code in aiperf/common/models/record_models.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class MetricResult(AIPerfBaseModel):
    """The result values of a single metric."""

    tag: str = Field(description="The unique identifier of the metric")
    unit: str = Field(description="The unit of the metric, e.g. 'ms'")
    header: str = Field(
        description="The user friendly name of the metric (e.g. 'Inter Token Latency')"
    )
    avg: float | None = None
    min: float | None = None
    max: float | None = None
    p1: float | None = None
    p5: float | None = None
    p25: float | None = None
    p50: float | None = None
    p75: float | None = None
    p90: float | None = None
    p95: float | None = None
    p99: float | None = None
    std: float | None = None
    count: int | None = Field(
        default=None,
        description="The total number of records used to calculate the metric",
    )
    streaming_only: bool = Field(
        default=False,
        description="Whether the metric only applies when streaming is enabled",
    )

ParsedResponseRecord

Bases: AIPerfBaseModel

Record of a request and its associated responses, already parsed and ready for metrics.

Source code in aiperf/common/models/record_models.py
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
class ParsedResponseRecord(AIPerfBaseModel):
    """Record of a request and its associated responses, already parsed and ready for metrics."""

    worker_id: str = Field(
        description="The ID of the worker that processed the request."
    )
    request: RequestRecord = Field(description="The original request record")
    responses: list[ResponseData] = Field(description="The parsed response data.")
    input_token_count: int | None = Field(
        default=None,
        description="The number of tokens in the input. If None, the number of tokens could not be calculated.",
    )
    output_token_count: int | None = Field(
        default=None,
        description="The number of tokens across all responses. If None, the number of tokens could not be calculated.",
    )

    @cached_property
    def start_perf_ns(self) -> int:
        """Get the start time of the request in nanoseconds (perf_counter_ns)."""
        return self.request.start_perf_ns

    @cached_property
    def timestamp_ns(self) -> int:
        """Get the wall clock timestamp of the request in nanoseconds. DO NOT USE FOR LATENCY CALCULATIONS. (time.time_ns)."""
        return self.request.timestamp_ns

    # TODO: How do we differentiate the end of the request vs the time of the last response?
    #       Which one should we use for the latency metrics?
    @cached_property
    def end_perf_ns(self) -> int:
        """Get the end time of the request in nanoseconds (perf_counter_ns).
        If request.end_perf_ns is not set, use the time of the last response.
        If there are no responses, use sys.maxsize.
        """
        return (
            self.request.end_perf_ns
            if self.request.end_perf_ns
            else self.responses[-1].perf_ns
            if self.responses
            else sys.maxsize
        )

    @cached_property
    def request_duration_ns(self) -> int:
        """Get the duration of the request in nanoseconds."""
        return self.end_perf_ns - self.start_perf_ns

    @cached_property
    def tokens_per_second(self) -> float | None:
        """Get the number of tokens per second of the request."""
        if self.output_token_count is None or self.request_duration_ns == 0:
            return None
        return self.output_token_count / (self.request_duration_ns / NANOS_PER_SECOND)

    @cached_property
    def has_error(self) -> bool:
        """Check if the response record has an error."""
        return self.request.has_error

    @cached_property
    def valid(self) -> bool:
        """Check if the response record is valid.

        Checks:
        - Request has no errors
        - Has at least one response
        - Start time is before the end time
        - Response timestamps are within valid ranges

        Returns:
            bool: True if the record is valid, False otherwise.
        """
        return (
            not self.has_error
            and len(self.responses) > 0
            and 0 <= self.start_perf_ns < self.end_perf_ns < sys.maxsize
            and all(0 < response.perf_ns < sys.maxsize for response in self.responses)
        )

end_perf_ns cached property

Get the end time of the request in nanoseconds (perf_counter_ns). If request.end_perf_ns is not set, use the time of the last response. If there are no responses, use sys.maxsize.

has_error cached property

Check if the response record has an error.

request_duration_ns cached property

Get the duration of the request in nanoseconds.

start_perf_ns cached property

Get the start time of the request in nanoseconds (perf_counter_ns).

timestamp_ns cached property

Get the wall clock timestamp of the request in nanoseconds. DO NOT USE FOR LATENCY CALCULATIONS. (time.time_ns).

tokens_per_second cached property

Get the number of tokens per second of the request.

valid cached property

Check if the response record is valid.

Checks: - Request has no errors - Has at least one response - Start time is before the end time - Response timestamps are within valid ranges

Returns:

Name Type Description
bool bool

True if the record is valid, False otherwise.

RequestRecord

Bases: AIPerfBaseModel

Record of a request with its associated responses.

Source code in aiperf/common/models/record_models.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
class RequestRecord(AIPerfBaseModel):
    """Record of a request with its associated responses.

    Timing fields come in two flavors: ``timestamp_ns`` is wall-clock time
    (``time.time_ns``) and must not be used for latency math, while the
    ``*_perf_ns`` fields come from ``time.perf_counter_ns`` and are the
    basis for all latency calculations.
    """

    request: Any | None = Field(
        default=None,
        description="The request payload formatted for the inference API.",
    )
    conversation_id: str | None = Field(
        default=None,
        description="The ID of the conversation (if applicable).",
    )
    turn_index: int | None = Field(
        default=None,
        ge=0,
        description="The index of the turn in the conversation (if applicable).",
    )
    model_name: str | None = Field(
        default=None,
        description="The name of the model targeted by the request.",
    )
    timestamp_ns: int = Field(
        default_factory=time.time_ns,
        description="The wall clock timestamp of the request in nanoseconds. DO NOT USE FOR LATENCY CALCULATIONS. (time.time_ns).",
    )
    start_perf_ns: int = Field(
        default_factory=time.perf_counter_ns,
        description="The start reference time of the request in nanoseconds used for latency calculations (perf_counter_ns).",
    )
    end_perf_ns: int | None = Field(
        default=None,
        description="The end time of the request in nanoseconds (perf_counter_ns).",
    )
    recv_start_perf_ns: int | None = Field(
        default=None,
        description="The start time of the streaming response in nanoseconds (perf_counter_ns).",
    )
    status: int | None = Field(
        default=None,
        description="The HTTP status code of the response.",
    )
    # NOTE: We need to use SerializeAsAny to allow for generic subclass support
    # NOTE: The order of the types is important, as that is the order they are type checked.
    #       Start with the most specific types and work towards the most general types.
    responses: SerializeAsAny[
        list[SSEMessage | TextResponse | InferenceServerResponse | Any]
    ] = Field(
        default_factory=list,
        description="The raw responses received from the request.",
    )
    error: ErrorDetails | None = Field(
        default=None,
        description="The error details if the request failed.",
    )
    delayed_ns: int | None = Field(
        default=None,
        ge=0,
        description="The number of nanoseconds the request was delayed from when it was expected to be sent, "
        "or None if the request was sent on time, or did not have a credit_drop_ns timestamp.",
    )
    credit_phase: CreditPhase = Field(
        default=CreditPhase.PROFILING,
        description="The type of credit phase (either warmup or profiling)",
    )

    @property
    def delayed(self) -> bool:
        """Check if the request was delayed (sent later than scheduled)."""
        return self.delayed_ns is not None and self.delayed_ns > 0

    # TODO: Most of these properties will be removed once we have proper record handling and metrics.

    @property
    def has_error(self) -> bool:
        """Check if the request record has an error."""
        return self.error is not None

    @property
    def valid(self) -> bool:
        """Check if the request record is valid by ensuring that the start time
        and response timestamps are within valid ranges.

        Returns:
            bool: True if the record is valid, False otherwise.
        """
        return not self.has_error and (
            0 <= self.start_perf_ns < sys.maxsize
            and len(self.responses) > 0
            and all(0 < response.perf_ns < sys.maxsize for response in self.responses)
        )

    @property
    def time_to_first_response_ns(self) -> int | None:
        """Get the time from request start to the first response in nanoseconds."""
        if not self.valid:
            return None
        return (
            self.responses[0].perf_ns - self.start_perf_ns
            if self.start_perf_ns
            else None
        )

    @property
    def time_to_second_response_ns(self) -> int | None:
        """Get the interval between the first and second responses in nanoseconds."""
        if not self.valid or len(self.responses) < 2:
            return None
        return (
            self.responses[1].perf_ns - self.responses[0].perf_ns
            if self.responses[1].perf_ns and self.responses[0].perf_ns
            else None
        )

    @property
    def time_to_last_response_ns(self) -> int | None:
        """Get the time from request start to the last response in nanoseconds."""
        if not self.valid:
            return None
        if self.end_perf_ns is None or self.start_perf_ns is None:
            return None
        return self.end_perf_ns - self.start_perf_ns if self.start_perf_ns else None

    @property
    def inter_token_latency_ns(self) -> float | None:
        """Get the average interval between responses in nanoseconds.

        If the final response is an SSE "[DONE]" sentinel, it carries no
        content and is excluded from the calculation.

        Returns:
            float | None: The average inter-response interval, or None if
                there are fewer than two measurable responses.
        """
        if not self.valid or len(self.responses) < 2:
            return None

        responses = self.responses
        if (
            isinstance(responses[-1], SSEMessage)
            # BUGFIX: guard against an SSE message with no packets before
            # indexing packets[-1].
            and responses[-1].packets
            and responses[-1].packets[-1].value == "[DONE]"
        ):
            # Drop the terminal "[DONE]" sentinel; it is not a content-bearing response.
            responses = responses[:-1]

        # BUGFIX: after dropping the sentinel there may be only one response
        # left; the previous code divided by (len(responses) - 2) which is
        # zero in that case and raised ZeroDivisionError.
        if len(responses) < 2:
            return None
        if not (responses[0].perf_ns and responses[-1].perf_ns):
            return None
        return (responses[-1].perf_ns - responses[0].perf_ns) / (len(responses) - 1)

    def token_latency_ns(self, index: int) -> float | None:
        """Get the latency of the response at ``index`` in nanoseconds.

        Index 0 is measured relative to ``recv_start_perf_ns``; subsequent
        indices are measured relative to the previous response's ``perf_ns``.
        """
        if not self.valid or len(self.responses) < 1:
            return None
        if index == 0:
            return (
                self.responses[0].perf_ns - self.recv_start_perf_ns
                if self.recv_start_perf_ns
                else None
            )
        return (
            self.responses[index].perf_ns - self.responses[index - 1].perf_ns
            if self.responses[index].perf_ns and self.responses[index - 1].perf_ns
            else None
        )

delayed property

Check if the request was delayed.

has_error property

Check if the request record has an error.

inter_token_latency_ns property

Get the interval between responses in nanoseconds.

time_to_first_response_ns property

Get the time to the first response in nanoseconds.

time_to_last_response_ns property

Get the time to the last response in nanoseconds.

time_to_second_response_ns property

Get the time to the second response in nanoseconds.

valid property

Check if the request record is valid by ensuring that the start time and response timestamps are within valid ranges.

Returns:

Name Type Description
bool bool

True if the record is valid, False otherwise.

token_latency_ns(index)

Get the latency of a token in nanoseconds.

Source code in aiperf/common/models/record_models.py
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
def token_latency_ns(self, index: int) -> float | None:
    """Get the latency of a token in nanoseconds.

    Index 0 is measured relative to ``recv_start_perf_ns`` (the start of the
    streaming response); later indices are measured relative to the previous
    response's ``perf_ns``.
    """
    if not self.valid or len(self.responses) < 1:
        return None
    if index == 0:
        # Without recv_start_perf_ns there is no reference point for the first token.
        return (
            self.responses[0].perf_ns - self.recv_start_perf_ns
            if self.recv_start_perf_ns
            else None
        )
    return (
        self.responses[index].perf_ns - self.responses[index - 1].perf_ns
        if self.responses[index].perf_ns and self.responses[index - 1].perf_ns
        else None
    )

ResponseData

Bases: AIPerfBaseModel

Base class for all response data.

Source code in aiperf/common/models/record_models.py
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
class ResponseData(AIPerfBaseModel):
    """Base class for all response data.

    Holds the timing, text, and token information parsed from a single
    inference-server response.
    """

    # Performance-counter timestamp of the response; presumably from
    # time.perf_counter_ns like the other *_perf_ns fields — TODO confirm.
    perf_ns: int = Field(description="The performance timestamp of the response.")
    raw_text: list[str] = Field(description="The raw text of the response.")
    # Entries may be None where a chunk could not be parsed.
    parsed_text: list[str | None] = Field(
        description="The parsed text of the response."
    )
    token_count: int | None = Field(
        default=None,
        description="The total number of tokens in the response from the parsed text.",
    )
    metadata: dict[str, Any] = Field(
        default_factory=dict, description="The metadata of the response."
    )

SSEField

Bases: AIPerfBaseModel

Base model for a single field in an SSE message.

Source code in aiperf/common/models/record_models.py
74
75
76
77
78
79
80
81
82
83
84
class SSEField(AIPerfBaseModel):
    """Base model for a single field in an SSE message.

    An SSE message consists of one or more "name: value" lines; this model
    captures a single such line.
    """

    # Well-known names use SSEFieldType; unrecognized names are kept as raw strings.
    name: SSEFieldType | str = Field(
        ...,
        description="The name of the field. e.g. 'data', 'event', 'id', 'retry', 'comment'.",
    )
    # None for fields that appear without a value.
    value: str | None = Field(
        default=None,
        description="The value of the field.",
    )

SSEMessage

Bases: InferenceServerResponse

Individual SSE message from an SSE stream. Delimited by a blank line (`\n\n`).

Source code in aiperf/common/models/record_models.py
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
class SSEMessage(InferenceServerResponse):
    """Individual SSE message from an SSE stream. Delimited by \n\n."""

    # Note: "fields" is a restricted keyword in pydantic
    packets: list[SSEField] = Field(
        default_factory=list,
        description="The fields contained in the message.",
    )

    def extract_data_content(self) -> list[str]:
        """Extract the data contents from the SSE message as a list of strings. Note that the SSE spec specifies
        that each data content should be combined and delimited by a single \n. We have left
        it as a list to allow the caller to decide how to handle the data.

        Returns:
            list[str]: A list of strings containing the data contents of the SSE message.
        """
        # Only 'data' fields carry payload; fields without a value are skipped.
        return [
            packet.value
            for packet in self.packets
            if packet.name == SSEFieldType.DATA and packet.value is not None
        ]

extract_data_content()

Extract the data contents from the SSE message as a list of strings. Note that the SSE spec specifies that each data content should be combined and delimited by a single newline (`\n`). We have left it as a list to allow the caller to decide how to handle the data.

    Returns:
        list[str]: A list of strings containing the data contents of the SSE message.
Source code in aiperf/common/models/record_models.py
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
def extract_data_content(self) -> list[str]:
    """Collect the values of all 'data' fields in this SSE message.

    The SSE spec says data lines should be combined and delimited by a
    single \n; we deliberately return the raw list so the caller can decide
    how to combine them.

    Returns:
        list[str]: The data contents of the SSE message, in order.
    """
    contents: list[str] = []
    for field in self.packets:
        # Skip non-data fields and data fields that carry no value.
        if field.name != SSEFieldType.DATA:
            continue
        if field.value is None:
            continue
        contents.append(field.value)
    return contents

TextResponse

Bases: InferenceServerResponse

Raw text response from an inference client including an optional content type.

Source code in aiperf/common/models/record_models.py
61
62
63
64
65
66
67
68
69
70
71
class TextResponse(InferenceServerResponse):
    """Raw text response from an inference client including an optional content type."""

    # MIME type of the body, when the server provided one.
    content_type: str | None = Field(
        default=None,
        description="The content type of the response. e.g. 'text/plain', 'application/json'.",
    )
    text: str = Field(
        ...,
        description="The text of the response.",
    )

aiperf.common.models.service_models

ServiceRunInfo

Bases: BaseModel

Base model for tracking service run information.

Source code in aiperf/common/models/service_models.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class ServiceRunInfo(BaseModel):
    """Base model for tracking service run information."""

    service_type: ServiceType = Field(
        ...,
        description="The type of service",
    )
    registration_status: ServiceRegistrationStatus = Field(
        ...,
        description="The registration status of the service",
    )
    service_id: str = Field(
        ...,
        description="The ID of the service",
    )
    # Wall-clock nanoseconds (time.time_ns). NOTE(review): declared Optional
    # but defaults to "now" via the factory — confirm whether None is ever set.
    first_seen: int | None = Field(
        default_factory=time.time_ns,
        description="The first time the service was seen",
    )
    # Wall-clock nanoseconds (time.time_ns); same Optional-vs-factory note as above.
    last_seen: int | None = Field(
        default_factory=time.time_ns,
        description="The last time the service was seen",
    )
    state: ServiceState = Field(
        default=ServiceState.UNKNOWN,
        description="The current state of the service",
    )

aiperf.common.models.worker_models

WorkerPhaseTaskStats

Bases: AIPerfBaseModel

Stats for the tasks that have been sent to the worker for a given credit phase.

Source code in aiperf/common/models/worker_models.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class WorkerPhaseTaskStats(AIPerfBaseModel):
    """Stats for the tasks that have been sent to the worker for a given credit phase."""

    total: int = Field(
        default=0,
        description="The total number of tasks that have been sent to the worker. "
        "Not all tasks will be completed.",
    )
    failed: int = Field(
        default=0,
        description="The number of tasks that returned an error",
    )
    completed: int = Field(
        default=0,
        description="The number of tasks that were completed successfully",
    )

    @property
    def in_progress(self) -> int:
        """The number of tasks that are currently in progress.

        This is the total number of tasks sent to the worker minus the number of failed and successfully completed tasks.
        """
        # Derived on demand; assumes completed + failed never exceeds total.
        return self.total - self.completed - self.failed

in_progress property

The number of tasks that are currently in progress.

This is the total number of tasks sent to the worker minus the number of failed and successfully completed tasks.

aiperf.common.service.base_component_service

BaseComponentService

Bases: BaseService

Base class for all Component services.

This class provides a common interface for all Component services in the AIPerf framework such as the Timing Manager, Dataset Manager, etc.

It extends the BaseService by: - Subscribing to the command message_type - Processing command messages - Sending registration requests to the system controller - Sending heartbeat notifications to the system controller - Sending status notifications to the system controller - Helpers to create heartbeat, registration, and status messages - Request the appropriate communication clients for a component service

Source code in aiperf/common/service/base_component_service.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
class BaseComponentService(BaseService):
    """Base class for all Component services.

    This class provides a common interface for all Component services in the AIPerf
    framework such as the Timing Manager, Dataset Manager, etc.

    It extends the BaseService by:
    - Subscribing to the command message_type
    - Processing command messages
    - Sending registration requests to the system controller
    - Sending heartbeat notifications to the system controller
    - Sending status notifications to the system controller
    - Helpers to create heartbeat, registration, and status messages
    - Request the appropriate communication clients for a component service
    """

    def __init__(
        self,
        service_config: ServiceConfig,
        user_config: UserConfig | None = None,
        service_id: str | None = None,
        **kwargs,
    ) -> None:
        super().__init__(
            service_config=service_config,
            user_config=user_config,
            service_id=service_id,
            **kwargs,
        )

        # Per-command handlers registered via register_command_callback().
        self._command_callbacks: dict[
            CommandType, Callable[[CommandMessage], Awaitable[None]]
        ] = {}
        self._heartbeat_interval_seconds = (
            self.service_config.heartbeat_interval_seconds
        )

    @on_init
    async def _on_init(self) -> None:
        """Automatically subscribe to the command message_type and register the service
        with the system controller when the run hook is called.

        This method will:
        - Subscribe to the command message_type
        - Wait for the communication to be fully initialized
        - Register the service with the system controller
        """
        # Subscribe to the command message_type
        try:
            await self.sub_client.subscribe(
                MessageType.COMMAND,
                self.process_command_message,
            )
        except Exception as e:
            raise self._service_error("Failed to subscribe to command topic") from e

        # Register the service
        try:
            await self.register()
        except Exception as e:
            raise self._service_error("Failed to register service") from e

    @aiperf_task
    async def _heartbeat_task(self) -> None:
        """Starts a background task to send heartbeats at regular intervals. It
        will continue to send heartbeats even if an error occurs until the stop
        event is set.
        """
        while not self.stop_event.is_set():
            # Sleep first to avoid sending a heartbeat before the registration
            # message has been published
            await asyncio.sleep(self._heartbeat_interval_seconds)

            try:
                await self.send_heartbeat()
            except Exception as e:
                self.logger.error("Exception sending heartbeat: %s", e)
                # continue to keep sending heartbeats regardless of the error

        self.logger.debug("Heartbeat task stopped")

    async def send_heartbeat(self) -> None:
        """Send a heartbeat notification to the system controller.

        Raises:
            The service error wrapping any publish failure.
        """
        heartbeat_message = self.create_heartbeat_message()
        self.logger.debug("Sending heartbeat: %s", heartbeat_message)
        try:
            await self.pub_client.publish(
                message=heartbeat_message,
            )
        except Exception as e:
            raise self._service_error("Failed to send heartbeat") from e

    async def register(self) -> None:
        """Publish a registration request to the system controller.

        This method should be called after the service has been initialized and is
        ready to start processing messages.
        """
        self.logger.info(
            "Attempting to register service %s (%s) with system controller",
            self.service_type,
            self.service_id,
        )
        try:
            await self.pub_client.publish(
                message=self.create_registration_message(),
            )
        except Exception as e:
            raise self._service_error("Failed to register service") from e

    async def process_command_message(self, message: CommandMessage) -> None:
        """Process a command message received from the controller.

        This method will process the command message and execute the appropriate action.
        """
        if message.target_service_id and message.target_service_id != self.service_id:
            return  # Ignore commands meant for other services
        if (
            message.target_service_type
            and message.target_service_type != self.service_type
        ):
            return  # Ignore commands meant for other services

        self.logger.debug(
            "%s: Processing command message: %s", self.service_id, message
        )
        cmd = message.command
        response_data = None
        try:
            if cmd == CommandType.PROFILE_START:
                response_data = await self.start()

            elif cmd == CommandType.SHUTDOWN:
                self.logger.debug("%s: Received stop command", self.service_id)
                self.stop_event.set()

            elif cmd == CommandType.PROFILE_CONFIGURE:
                await self.run_hooks(AIPerfHook.ON_CONFIGURE, message)

            elif cmd in self._command_callbacks:
                response_data = await self._command_callbacks[cmd](message)

            else:
                raise self._service_error(
                    f"Received unknown command: {cmd}",
                )

            # Publish the success response
            await self.pub_client.publish(
                CommandResponseMessage(
                    service_id=self.service_id,
                    command=cmd,
                    command_id=message.command_id,
                    status=CommandResponseStatus.SUCCESS,
                    data=response_data,
                ),
            )

        except Exception as e:
            # Publish the failure response
            # NOTE: a failure raised while publishing the success response also
            # lands here, so both a SUCCESS and a FAILURE response could be
            # attempted for the same command_id.
            await self.pub_client.publish(
                CommandResponseMessage(
                    service_id=self.service_id,
                    command=cmd,
                    command_id=message.command_id,
                    status=CommandResponseStatus.FAILURE,
                    error=ErrorDetails.from_exception(e),
                ),
            )

    def register_command_callback(
        self,
        cmd: CommandType,
        callback: Callable[[CommandMessage], Awaitable[None]],
    ) -> None:
        """Register a single callback for a command.

        NOTE(review): the annotation says Awaitable[None], but
        process_command_message publishes the callback's return value as
        response data — consider Awaitable[Any].
        """
        self._command_callbacks[cmd] = callback

    @on_set_state
    async def _on_set_state(self, state: ServiceState) -> None:
        """Action to take when the service state is set.

        This method will also publish the status message to the status message_type if the
        communications are initialized.
        """
        # Only publish when the pub client is usable (initialized, not stopping).
        if (
            self.pub_client
            and self.pub_client.is_initialized
            and not self.pub_client.stop_event.is_set()
        ):
            await self.pub_client.publish(
                self.create_status_message(state),
            )

    def create_heartbeat_message(self) -> HeartbeatMessage:
        """Create a heartbeat notification message carrying ID, type, and state."""
        return HeartbeatMessage(
            service_id=self.service_id,
            service_type=self.service_type,
            state=self.state,
        )

    def create_registration_message(self) -> RegistrationMessage:
        """Create a registration request message identifying this service."""
        return RegistrationMessage(
            service_id=self.service_id,
            service_type=self.service_type,
        )

    def create_status_message(self, state: ServiceState) -> StatusMessage:
        """Create a status notification message for the given state."""
        return StatusMessage(
            service_id=self.service_id,
            state=state,
            service_type=self.service_type,
        )

create_heartbeat_message()

Create a heartbeat notification message.

Source code in aiperf/common/service/base_component_service.py
219
220
221
222
223
224
225
def create_heartbeat_message(self) -> HeartbeatMessage:
    """Create a heartbeat notification message carrying the service's ID, type, and current state."""
    return HeartbeatMessage(
        service_id=self.service_id,
        service_type=self.service_type,
        state=self.state,
    )

create_registration_message()

Create a registration request message.

Source code in aiperf/common/service/base_component_service.py
227
228
229
230
231
232
def create_registration_message(self) -> RegistrationMessage:
    """Create a registration request message identifying this service by ID and type."""
    return RegistrationMessage(
        service_id=self.service_id,
        service_type=self.service_type,
    )

create_status_message(state)

Create a status notification message.

Source code in aiperf/common/service/base_component_service.py
234
235
236
237
238
239
240
def create_status_message(self, state: ServiceState) -> StatusMessage:
    """Create a status notification message."""
    return StatusMessage(
        service_id=self.service_id,
        state=state,
        service_type=self.service_type,
    )

process_command_message(message) async

Process a command message received from the controller.

This method will process the command message and execute the appropriate action.

Source code in aiperf/common/service/base_component_service.py
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
async def process_command_message(self, message: CommandMessage) -> None:
    """Process a command message received from the controller.

    This method will process the command message and execute the appropriate action.
    """
    if message.target_service_id and message.target_service_id != self.service_id:
        return  # Ignore commands meant for other services
    if (
        message.target_service_type
        and message.target_service_type != self.service_type
    ):
        return  # Ignore commands meant for other services

    self.logger.debug(
        "%s: Processing command message: %s", self.service_id, message
    )
    cmd = message.command
    response_data = None
    try:
        if cmd == CommandType.PROFILE_START:
            response_data = await self.start()

        elif cmd == CommandType.SHUTDOWN:
            self.logger.debug("%s: Received stop command", self.service_id)
            self.stop_event.set()

        elif cmd == CommandType.PROFILE_CONFIGURE:
            await self.run_hooks(AIPerfHook.ON_CONFIGURE, message)

        elif cmd in self._command_callbacks:
            response_data = await self._command_callbacks[cmd](message)

        else:
            raise self._service_error(
                f"Received unknown command: {cmd}",
            )

        # Publish the success response
        await self.pub_client.publish(
            CommandResponseMessage(
                service_id=self.service_id,
                command=cmd,
                command_id=message.command_id,
                status=CommandResponseStatus.SUCCESS,
                data=response_data,
            ),
        )

    except Exception as e:
        # Publish the failure response
        # NOTE: a failure raised while publishing the success response also
        # lands here, so both responses may be attempted for one command_id.
        await self.pub_client.publish(
            CommandResponseMessage(
                service_id=self.service_id,
                command=cmd,
                command_id=message.command_id,
                status=CommandResponseStatus.FAILURE,
                error=ErrorDetails.from_exception(e),
            ),
        )

register() async

Publish a registration request to the system controller.

This method should be called after the service has been initialized and is ready to start processing messages.

Source code in aiperf/common/service/base_component_service.py
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
async def register(self) -> None:
    """Publish a registration request to the system controller.

    This method should be called after the service has been initialized and is
    ready to start processing messages.
    """
    self.logger.info(
        "Attempting to register service %s (%s) with system controller",
        self.service_type,
        self.service_id,
    )
    try:
        await self.pub_client.publish(
            message=self.create_registration_message(),
        )
    # Registration failures are wrapped in a service error and re-raised.
    except Exception as e:
        raise self._service_error("Failed to register service") from e

register_command_callback(cmd, callback)

Register a single callback for a command.

Source code in aiperf/common/service/base_component_service.py
195
196
197
198
199
200
201
def register_command_callback(
    self,
    cmd: CommandType,
    callback: Callable[[CommandMessage], Awaitable[None]],
) -> None:
    """Register a single callback for a command.

    NOTE(review): the annotation says Awaitable[None], yet the callback's
    return value is used as command-response data — consider Awaitable[Any].
    """
    self._command_callbacks[cmd] = callback

send_heartbeat() async

Send a heartbeat notification to the system controller.

Source code in aiperf/common/service/base_component_service.py
106
107
108
109
110
111
112
113
114
115
async def send_heartbeat(self) -> None:
    """Send a heartbeat notification to the system controller."""
    heartbeat_message = self.create_heartbeat_message()
    self.logger.debug("Sending heartbeat: %s", heartbeat_message)
    try:
        await self.pub_client.publish(
            message=heartbeat_message,
        )
    # Publish failures are wrapped in a service error and re-raised.
    except Exception as e:
        raise self._service_error("Failed to send heartbeat") from e

aiperf.common.service.base_controller_service

BaseControllerService

Bases: BaseService

Base class for all controller services, such as the System Controller.

This class provides a common interface for all controller services in the AIPerf framework. It inherits from the BaseService class and implements the required methods for controller services.

It extends the BaseService by: - Starting the service automatically when the run hook is called - Helpers to create command messages to be sent to a specific service - Request the appropriate communication clients for a controller service

Source code in aiperf/common/service/base_controller_service.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
class BaseControllerService(BaseService):
    """Base class for all controller services, such as the System Controller.

    This class provides a common interface for all controller services in the AIPerf
    framework. It inherits from the BaseService class and implements the required
    methods for controller services.

    It extends the BaseService by:
    - Starting the service automatically when the run hook is called
    - Helpers to create command messages to be sent to a specific service
    - Request the appropriate communication clients for a controller service
    """

    # NOTE: No __init__ override is needed. The previous override only forwarded
    # the identical (service_config, user_config, service_id, **kwargs) arguments
    # to BaseService.__init__, so the inherited constructor is used directly.

    @on_run
    async def _on_run(self) -> None:
        """Automatically start the service when the run hook is called."""
        await self.start()

    def create_command_message(
        self,
        command: CommandType,
        target_service_id: str | None,
        target_service_type: ServiceType | None = None,
        data: BaseModel | None = None,
    ) -> CommandMessage:
        """Create a command message to be sent to a specific service.

        Args:
            command: The command to send
            target_service_id: The ID of the service to send the command to
            target_service_type: The type of the service to send the command to
            data: Optional data to send with the command.

        Returns:
            A command message stamped with this controller's service ID.
        """
        return CommandMessage(
            service_id=self.service_id,
            command=command,
            target_service_id=target_service_id,
            target_service_type=target_service_type,
            data=data,
        )

create_command_message(command, target_service_id, target_service_type=None, data=None)

Create a command message to be sent to a specific service.

Parameters:

Name Type Description Default
command CommandType

The command to send

required
target_service_id str | None

The ID of the service to send the command to

required
target_service_type ServiceType | None

The type of the service to send the command to

None
data BaseModel | None

Optional data to send with the command.

None

Returns:

Type Description
CommandMessage

A command message

Source code in aiperf/common/service/base_controller_service.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def create_command_message(
    self,
    command: CommandType,
    target_service_id: str | None,
    target_service_type: ServiceType | None = None,
    data: BaseModel | None = None,
) -> CommandMessage:
    """Build a command message addressed to a specific service.

    Args:
        command: The command to issue.
        target_service_id: ID of the receiving service, when targeting by ID.
        target_service_type: Type of the receiving service, when targeting by type.
        data: Optional payload to attach to the command.

    Returns:
        The populated command message, stamped with this service's own ID.
    """
    message = CommandMessage(
        service_id=self.service_id,
        command=command,
        target_service_id=target_service_id,
        target_service_type=target_service_type,
        data=data,
    )
    return message

aiperf.common.service.base_service

BaseService

Bases: BaseServiceInterface, ABC, AIPerfTaskMixin, AIPerfLoggerMixin

Base class for all AIPerf services, providing common functionality for communication, state management, and lifecycle operations.

This class provides the foundation for implementing the various services of the AIPerf system. Some of the abstract methods are implemented here, while others are still required to be implemented by derived classes.

Source code in aiperf/common/service/base_service.py
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
@supports_hooks(
    AIPerfHook.ON_INIT,
    AIPerfHook.ON_RUN,
    AIPerfHook.ON_CONFIGURE,
    AIPerfHook.ON_START,
    AIPerfHook.ON_STOP,
    AIPerfHook.ON_CLEANUP,
    AIPerfHook.ON_SET_STATE,
    AIPerfTaskHook.AIPERF_TASK,
)
class BaseService(BaseServiceInterface, ABC, AIPerfTaskMixin, AIPerfLoggerMixin):
    """Base class for all AIPerf services, providing common functionality for
    communication, state management, and lifecycle operations.

    This class provides the foundation for implementing the various services of the
    AIPerf system. Some of the abstract methods are implemented here, while others
    are still required to be implemented by derived classes.
    """

    def __init__(
        self,
        service_config: ServiceConfig,
        user_config: UserConfig | None = None,
        service_id: str | None = None,
        **kwargs,
    ) -> None:
        # Generate a unique, addressable ID when the caller did not supply one.
        self.service_id: str = (
            service_id or f"{self.service_type}_{uuid.uuid4().hex[:8]}"
        )
        self.service_config = service_config
        self.user_config = user_config

        self._state: ServiceState = ServiceState.UNKNOWN

        # Forward configuration up the MRO; the mixins consume logger_name etc.
        super().__init__(
            service_id=service_id,
            service_config=service_config,
            user_config=user_config,
            logger_name=self.service_id,
            **kwargs,
        )

        self.debug(
            lambda: f"__init__ {self.service_type} service (id: {self.service_id})"
        )

        # NOTE(review): duplicate of the `_state` assignment above -- one of the
        # two initializations is redundant.
        self._state: ServiceState = ServiceState.UNKNOWN

        # Lifecycle events: `stop_event` tells run_forever/_forever_loop to exit;
        # `initialized_event` is set once initialize() has completed.
        self.stop_event = asyncio.Event()
        self.initialized_event = asyncio.Event()

        # Create the communication backend and the pub/sub clients bound to the
        # event-bus proxy endpoints.
        self.comms: BaseCommunication = CommunicationFactory.create_instance(
            self.service_config.comm_backend,
            config=self.service_config.comm_config,
        )
        self.sub_client = self.comms.create_sub_client(
            CommunicationClientAddressType.EVENT_BUS_PROXY_BACKEND
        )  # type: ignore
        self.pub_client = self.comms.create_pub_client(
            CommunicationClientAddressType.EVENT_BUS_PROXY_FRONTEND
        )  # type: ignore

        # Best-effort: label the OS process with the service id for easier ps/top.
        try:
            import setproctitle

            setproctitle.setproctitle(f"aiperf {self.service_id}")
        except Exception:
            # setproctitle is not available on all platforms, so we ignore the error
            self.logger.debug("Failed to set process title, ignoring")

        # NOTE(review): "_init__" in this log message looks like a typo for
        # "__init__" (log string only; behavior is unaffected).
        self.logger.debug(
            "BaseService._init__ finished for %s", self.__class__.__name__
        )

    @property
    def state(self) -> ServiceState:
        """The current state of the service."""
        return self._state

    @property
    def is_initialized(self) -> bool:
        """Check if service is initialized.

        Returns:
            True if service is initialized, False otherwise
        """
        return self.initialized_event.is_set()

    def _service_error(self, message: str) -> ServiceError:
        # Build a ServiceError tagged with this service's type and id.
        return ServiceError(
            message=message,
            service_type=self.service_type,
            service_id=self.service_id,
        )

    # Note: Not using as a setter so it can be overridden by derived classes and still
    # be async
    async def set_state(self, state: ServiceState) -> None:
        """Set the state of the service. This method implements
        the `BaseServiceInterface.set_state` method.

        This method will:
        - Set the service state to the given state
        - Call all registered `AIPerfHook.ON_SET_STATE` hooks
        """
        # Store first, then notify hooks of the transition.
        self._state = state
        await self.run_hooks(AIPerfHook.ON_SET_STATE, state)

    async def initialize(self) -> None:
        """Initialize the service communication and signal handlers. This method implements
        the `BaseServiceInterface.initialize` method.

        This method will:
        - Set the service to `ServiceState.INITIALIZING` state
        - Initialize communication
        - Call all registered `AIPerfHook.ON_INIT` hooks
        - Set the service to `ServiceState.READY` state
        - Set the initialized asyncio event
        """
        # Direct assignment bypasses set_state, so ON_SET_STATE hooks are not
        # run for the INITIALIZING transition.
        self._state = ServiceState.INITIALIZING

        await self.comms.initialize()

        # Initialize any derived service components
        await self.run_hooks(AIPerfHook.ON_INIT)
        await self.set_state(ServiceState.READY)

        # Unblock any waiters on `is_initialized` / `initialized_event`.
        self.initialized_event.set()

    async def run_forever(self) -> None:
        """Run the service in a loop until the stop event is set. This method implements
        the `BaseServiceInterface.run_forever` method.

        This method will:
        - Call the initialize method to initialize the service
        - Call all registered `AIPerfHook.ON_RUN` hooks
        - Wait for the stop event to be set
        - Shuts down the service when the stop event is set

        This method will be called as the main entry point for the service.
        """
        try:
            self.logger.debug(
                "Running %s service (id: %s)", self.service_type, self.service_id
            )

            await self.initialize()
            await self.run_hooks(AIPerfHook.ON_RUN)

        except asyncio.CancelledError:
            self.logger.debug("Service %s execution cancelled", self.service_type)
            return

        except AIPerfError:
            raise  # re-raise it up the stack

        except Exception as e:
            self.logger.exception("Service %s execution failed:", self.service_type)
            # NOTE(review): set_state returns None; the "_ =" assignment is superfluous.
            _ = await self.set_state(ServiceState.ERROR)
            raise self._service_error("Service execution failed") from e

        # Block here until stop is requested; stop() runs inside this loop.
        await self._forever_loop()

    async def _forever_loop(self) -> None:
        """
        This method will be called by the `run_forever` method to allow the service to run
        indefinitely. This method is not expected to be overridden by derived classes.

        This method will:
        - Wait for the stop event to be set
        - Shuts down the service when the stop event is set
        """
        while not self.stop_event.is_set():
            try:
                self.logger.debug(
                    "Service %s waiting for stop event", self.service_type
                )
                # Wait forever for the stop event to be set
                await self.stop_event.wait()

            except asyncio.CancelledError:
                self.logger.debug(
                    "Service %s received CancelledError, exiting",
                    self.service_type,
                )
                break

            except Exception:
                # Unexpected errors are logged and the wait is retried.
                self.logger.exception(
                    "Caught unexpected exception in service %s execution",
                    self.service_type,
                )

        # Shutdown the service
        try:
            await self.stop()
        except Exception:
            self.logger.exception(
                "Caught unexpected exception in service %s stop",
                self.service_type,
            )

    async def start(self) -> None:
        """Start the service and its components. This method implements
        the `BaseServiceInterface.start` method.

        This method should be called to start the service after it has been initialized
        and configured.

        This method will:
        - Set the service to `ServiceState.STARTING` state
        - Call all registered `AIPerfHook.ON_START` hooks
        - Set the service to `ServiceState.RUNNING` state
        """

        try:
            self.logger.debug(
                "Starting %s service (id: %s)", self.service_type, self.service_id
            )
            # NOTE(review): set_state returns None; the "_ =" assignments are superfluous.
            _ = await self.set_state(ServiceState.STARTING)

            await self.run_hooks(AIPerfHook.ON_START)

            _ = await self.set_state(ServiceState.RUNNING)

        except asyncio.CancelledError:
            # NOTE(review): cancellation during startup is silently ignored here.
            pass

        except Exception as e:
            self._state = ServiceState.ERROR
            raise self._service_error("Failed to start service") from e

    async def stop(self) -> None:
        """Stop the service and clean up its components. This method implements
        the `BaseServiceInterface.stop` method.

        This method will:
        - Set the service to `ServiceState.STOPPING` state
        - Call all registered `AIPerfHook.ON_STOP` hooks
        - Shutdown the service communication component
        - Call all registered `AIPerfHook.ON_CLEANUP` hooks
        - Set the service to `ServiceState.STOPPED` state
        """
        try:
            if self.state == ServiceState.STOPPED:
                self.logger.warning(
                    "Service %s state %s is already STOPPED, ignoring stop request",
                    self.service_type,
                    self.state,
                )
                return

            self._state = ServiceState.STOPPING

            # Signal the run method to exit if it hasn't already
            if not self.stop_event.is_set():
                self.stop_event.set()

            # A CancelledError raised in any stage below is remembered (not
            # raised immediately) so that shutdown and cleanup still complete.
            cancelled_error = None
            # Custom stop logic implemented by derived classes
            try:
                await self.run_hooks(AIPerfHook.ON_STOP)
            except asyncio.CancelledError as e:
                cancelled_error = e

            # Shutdown communication component
            if self.comms and not self.comms.stop_requested:
                try:
                    await self.comms.shutdown()
                except asyncio.CancelledError as e:
                    cancelled_error = e

            # Custom cleanup logic implemented by derived classes
            try:
                await self.run_hooks(AIPerfHook.ON_CLEANUP)
            except asyncio.CancelledError as e:
                cancelled_error = e

            # Set the state to STOPPED. Communications are shutdown, so we don't need to
            # publish a status message
            self._state = ServiceState.STOPPED
            # Workers and worker managers are excluded from the per-service stop
            # log -- presumably to reduce log noise; verify intent with authors.
            if self.service_type not in (
                ServiceType.WORKER,
                ServiceType.WORKER_MANAGER,
            ):
                self.logger.debug(
                    "Service %s (id: %s) stopped", self.service_type, self.service_id
                )

            # Re-raise the cancelled error if it was raised during the stop hooks
            if cancelled_error:
                raise cancelled_error

        except Exception as e:
            self._state = ServiceState.ERROR
            raise self._service_error("Failed to stop service") from e

    async def configure(self, message: Message) -> None:
        """Configure the service with the given configuration. This method implements
        the `BaseServiceInterface.configure` method.

        This method will:
        - Call all registered AIPerfHook.ON_CONFIGURE hooks
        """
        await self.run_hooks(AIPerfHook.ON_CONFIGURE, message)

is_initialized property

Check if service is initialized.

Returns:

Type Description
bool

True if service is initialized, False otherwise

state property

The current state of the service.

configure(message) async

Configure the service with the given configuration. This method implements the BaseServiceInterface.configure method.

This method will: - Call all registered AIPerfHook.ON_CONFIGURE hooks

Source code in aiperf/common/service/base_service.py
326
327
328
329
330
331
332
333
async def configure(self, message: Message) -> None:
    """Configure the service with the given configuration. This method implements
    the `BaseServiceInterface.configure` method.

    This method will:
    - Call all registered AIPerfHook.ON_CONFIGURE hooks

    Args:
        message: The configuration message forwarded unchanged to the hooks.
    """
    # All configuration work is delegated to the registered hooks.
    await self.run_hooks(AIPerfHook.ON_CONFIGURE, message)

initialize() async

Initialize the service communication and signal handlers. This method implements the BaseServiceInterface.initialize method.

This method will: - Set the service to ServiceState.INITIALIZING state - Initialize communication - Call all registered AIPerfHook.ON_INIT hooks - Set the service to ServiceState.READY state - Set the initialized asyncio event

Source code in aiperf/common/service/base_service.py
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
async def initialize(self) -> None:
    """Initialize the service communication and signal handlers. This method implements
    the `BaseServiceInterface.initialize` method.

    This method will:
    - Set the service to `ServiceState.INITIALIZING` state
    - Initialize communication
    - Call all registered `AIPerfHook.ON_INIT` hooks
    - Set the service to `ServiceState.READY` state
    - Set the initialized asyncio event
    """
    # Direct assignment bypasses set_state, so ON_SET_STATE hooks are not run
    # for the INITIALIZING transition.
    self._state = ServiceState.INITIALIZING

    await self.comms.initialize()

    # Initialize any derived service components
    await self.run_hooks(AIPerfHook.ON_INIT)
    await self.set_state(ServiceState.READY)

    # Unblock any waiters on `is_initialized` / `initialized_event`.
    self.initialized_event.set()

run_forever() async

Run the service in a loop until the stop event is set. This method implements the BaseServiceInterface.run_forever method.

This method will: - Call the initialize method to initialize the service - Call all registered AIPerfHook.RUN hooks - Wait for the stop event to be set - Shuts down the service when the stop event is set

This method will be called as the main entry point for the service.

Source code in aiperf/common/service/base_service.py
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
async def run_forever(self) -> None:
    """Run the service in a loop until the stop event is set. This method implements
    the `BaseServiceInterface.run_forever` method.

    This method will:
    - Call the initialize method to initialize the service
    - Call all registered `AIPerfHook.ON_RUN` hooks
    - Wait for the stop event to be set
    - Shuts down the service when the stop event is set

    This method will be called as the main entry point for the service.
    """
    try:
        self.logger.debug(
            "Running %s service (id: %s)", self.service_type, self.service_id
        )

        await self.initialize()
        await self.run_hooks(AIPerfHook.ON_RUN)

    except asyncio.CancelledError:
        self.logger.debug("Service %s execution cancelled", self.service_type)
        return

    except AIPerfError:
        raise  # re-raise it up the stack

    except Exception as e:
        self.logger.exception("Service %s execution failed:", self.service_type)
        # NOTE(review): set_state returns None; the "_ =" assignment is superfluous.
        _ = await self.set_state(ServiceState.ERROR)
        raise self._service_error("Service execution failed") from e

    # Block here until stop is requested; stop() runs inside this loop.
    await self._forever_loop()

set_state(state) async

Set the state of the service. This method implements the BaseServiceInterface.set_state method.

This method will: - Set the service state to the given state - Call all registered AIPerfHook.ON_SET_STATE hooks

Source code in aiperf/common/service/base_service.py
126
127
128
129
130
131
132
133
134
135
async def set_state(self, state: ServiceState) -> None:
    """Set the state of the service. This method implements
    the `BaseServiceInterface.set_state` method.

    This method will:
    - Set the service state to the given state
    - Call all registered `AIPerfHook.ON_SET_STATE` hooks
    """
    # Store first, then notify hooks of the transition.
    self._state = state
    await self.run_hooks(AIPerfHook.ON_SET_STATE, state)

start() async

Start the service and its components. This method implements the BaseServiceInterface.start method.

This method should be called to start the service after it has been initialized and configured.

This method will: - Set the service to ServiceState.STARTING state - Call all registered AIPerfHook.ON_START hooks - Set the service to ServiceState.RUNNING state

Source code in aiperf/common/service/base_service.py
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
async def start(self) -> None:
    """Start the service and its components. This method implements
    the `BaseServiceInterface.start` method.

    This method should be called to start the service after it has been initialized
    and configured.

    This method will:
    - Set the service to `ServiceState.STARTING` state
    - Call all registered `AIPerfHook.ON_START` hooks
    - Set the service to `ServiceState.RUNNING` state
    """

    try:
        self.logger.debug(
            "Starting %s service (id: %s)", self.service_type, self.service_id
        )
        # NOTE(review): set_state returns None; the "_ =" assignments are superfluous.
        _ = await self.set_state(ServiceState.STARTING)

        await self.run_hooks(AIPerfHook.ON_START)

        _ = await self.set_state(ServiceState.RUNNING)

    except asyncio.CancelledError:
        # NOTE(review): cancellation during startup is silently ignored here.
        pass

    except Exception as e:
        # Direct assignment: comms may be unusable, so no hooks are notified.
        self._state = ServiceState.ERROR
        raise self._service_error("Failed to start service") from e

stop() async

Stop the service and clean up its components. This method implements the BaseServiceInterface.stop method.

This method will: - Set the service to ServiceState.STOPPING state - Call all registered AIPerfHook.ON_STOP hooks - Shutdown the service communication component - Call all registered AIPerfHook.ON_CLEANUP hooks - Set the service to ServiceState.STOPPED state

Source code in aiperf/common/service/base_service.py
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
async def stop(self) -> None:
    """Stop the service and clean up its components. This method implements
    the `BaseServiceInterface.stop` method.

    This method will:
    - Set the service to `ServiceState.STOPPING` state
    - Call all registered `AIPerfHook.ON_STOP` hooks
    - Shutdown the service communication component
    - Call all registered `AIPerfHook.ON_CLEANUP` hooks
    - Set the service to `ServiceState.STOPPED` state
    """
    try:
        if self.state == ServiceState.STOPPED:
            self.logger.warning(
                "Service %s state %s is already STOPPED, ignoring stop request",
                self.service_type,
                self.state,
            )
            return

        self._state = ServiceState.STOPPING

        # Signal the run method to exit if it hasn't already
        if not self.stop_event.is_set():
            self.stop_event.set()

        # A CancelledError raised in any stage below is remembered (not raised
        # immediately) so that shutdown and cleanup still complete.
        cancelled_error = None
        # Custom stop logic implemented by derived classes
        try:
            await self.run_hooks(AIPerfHook.ON_STOP)
        except asyncio.CancelledError as e:
            cancelled_error = e

        # Shutdown communication component
        if self.comms and not self.comms.stop_requested:
            try:
                await self.comms.shutdown()
            except asyncio.CancelledError as e:
                cancelled_error = e

        # Custom cleanup logic implemented by derived classes
        try:
            await self.run_hooks(AIPerfHook.ON_CLEANUP)
        except asyncio.CancelledError as e:
            cancelled_error = e

        # Set the state to STOPPED. Communications are shutdown, so we don't need to
        # publish a status message
        self._state = ServiceState.STOPPED
        # Workers and worker managers are excluded from the per-service stop
        # log -- presumably to reduce log noise; verify intent with authors.
        if self.service_type not in (
            ServiceType.WORKER,
            ServiceType.WORKER_MANAGER,
        ):
            self.logger.debug(
                "Service %s (id: %s) stopped", self.service_type, self.service_id
            )

        # Re-raise the cancelled error if it was raised during the stop hooks
        if cancelled_error:
            raise cancelled_error

    except Exception as e:
        self._state = ServiceState.ERROR
        raise self._service_error("Failed to stop service") from e

aiperf.common.service.base_service_interface

BaseServiceInterface

Bases: ABC

Base interface for all services.

This class defines the base contract that every service must provide. Some methods are required to be implemented by derived classes, while others are meant to be implemented by the base class.

Source code in aiperf/common/service/base_service_interface.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
class BaseServiceInterface(ABC):
    """Base interface for all services.

    This class defines the base contract that every service must provide. Some
    methods are required to be implemented by derived classes, while others are
    meant to be implemented by the base class.
    """

    @property
    @abstractmethod
    def service_type(self) -> ServiceType:
        """The type/name of the service.

        This property should be implemented by derived classes to specify the
        type/name of the service."""
        # TODO: We can do this better by using a decorator to set the service type
        pass

    @abstractmethod
    async def set_state(self, state: ServiceState) -> None:
        """Set the state of the service.

        This method will be implemented by the base class, and extra
        functionality can be added by derived classes via the `@on_set_state`
        decorator.
        """
        pass

    @abstractmethod
    async def initialize(self) -> None:
        """Initialize the service.

        This method will be implemented by the base class.
        """
        pass

    @abstractmethod
    async def start(self) -> None:
        """Start the service. It should be called after the service has been initialized
        and configured.

        This method will be implemented by the base class, and extra
        functionality can be added by derived classes via the `@on_start`
        decorator.
        """
        pass

    @abstractmethod
    async def stop(self) -> None:
        """Stop the service.

        This method will be implemented by the base class, and extra
        functionality can be added by derived classes via the `@on_stop`
        decorator.
        """
        pass

    @abstractmethod
    async def configure(self, message: Message) -> None:
        """Configure the service with the given configuration.

        This method will be implemented by the base class, and extra
        functionality can be added by derived classes via the `@on_configure`
        decorator.
        """
        pass

    @abstractmethod
    async def run_forever(self) -> None:
        """Run the service. This method will be the primary entry point for the service
        and will be called by the bootstrap script. It should not return until the
        service is completely shutdown.

        This method will be implemented by the base class. Any additional
        functionality can be added by derived classes via the `@on_run`
        decorator.
        """
        pass

    @abstractmethod
    async def _forever_loop(self) -> None:
        """Run the service in a loop until the stop event is set. This method will be
        called by the `run_forever` method to allow the service to run indefinitely.

        This method will be implemented by the base class, and is not expected to be
        overridden by derived classes.
        """
        pass

service_type abstractmethod property

The type/name of the service.

This property should be implemented by derived classes to specify the type/name of the service.

configure(message) abstractmethod async

Configure the service with the given configuration.

This method will be implemented by the base class, and extra functionality can be added by derived classes via the @on_configure decorator.

Source code in aiperf/common/service/base_service_interface.py
66
67
68
69
70
71
72
73
74
@abstractmethod
async def configure(self, message: Message) -> None:
    """Configure the service with the given configuration.

    This method will be implemented by the base class, and extra
    functionality can be added by derived classes via the `@on_configure`
    decorator.
    """
    # Abstract: no behavior here by design.
    pass

initialize() abstractmethod async

Initialize the service.

This method will be implemented by the base class.

Source code in aiperf/common/service/base_service_interface.py
37
38
39
40
41
42
43
@abstractmethod
async def initialize(self) -> None:
    """Initialize the service.

    This method will be implemented by the base class.
    """
    # Abstract: no behavior here by design.
    pass

run_forever() abstractmethod async

Run the service. This method will be the primary entry point for the service and will be called by the bootstrap script. It should not return until the service is completely shutdown.

This method will be implemented by the base class. Any additional functionality can be added by derived classes via the @on_run decorator.

Source code in aiperf/common/service/base_service_interface.py
76
77
78
79
80
81
82
83
84
85
86
@abstractmethod
async def run_forever(self) -> None:
    """Run the service. This method will be the primary entry point for the service
    and will be called by the bootstrap script. It should not return until the
    service is completely shutdown.

    This method will be implemented by the base class. Any additional
    functionality can be added by derived classes via the `@on_run`
    decorator.
    """
    # Abstract: no behavior here by design.
    pass

set_state(state) abstractmethod async

Set the state of the service.

This method will be implemented by the base class, and extra functionality can be added by derived classes via the @on_set_state decorator.

Source code in aiperf/common/service/base_service_interface.py
27
28
29
30
31
32
33
34
35
@abstractmethod
async def set_state(self, state: ServiceState) -> None:
    """Set the state of the service.

    This method will be implemented by the base class, and extra
    functionality can be added by derived classes via the `@on_set_state`
    decorator.
    """
    # Abstract: no behavior here by design.
    pass

start() abstractmethod async

Start the service. It should be called after the service has been initialized and configured.

This method will be implemented by the base class, and extra functionality can be added by derived classes via the @on_start decorator.

Source code in aiperf/common/service/base_service_interface.py
45
46
47
48
49
50
51
52
53
54
@abstractmethod
async def start(self) -> None:
    """Start the service. It should be called after the service has been initialized
    and configured.

    This method will be implemented by the base class, and extra
    functionality can be added by derived classes via the `@on_start`
    decorator.
    """
    # Abstract: no behavior here by design.
    pass

stop() abstractmethod async

Stop the service.

This method will be implemented by the base class, and extra functionality can be added by derived classes via the @on_stop decorator.

Source code in aiperf/common/service/base_service_interface.py
56
57
58
59
60
61
62
63
64
@abstractmethod
async def stop(self) -> None:
    """Stop the service.

    The base class supplies the implementation; derived classes hook in
    additional behavior via the `@on_stop` decorator.
    """
    ...

aiperf.common.tokenizer

Tokenizer

This class provides a simplified interface for using Huggingface tokenizers, with default arguments for common operations.

Source code in aiperf/common/tokenizer.py
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
class Tokenizer:
    """
    Simplified wrapper over a Huggingface tokenizer that supplies default
    arguments for the common call, encode, and decode operations.
    """

    def __init__(self) -> None:
        """
        Create an unloaded wrapper holding the default argument sets for
        call, encode, and decode.
        """
        self._tokenizer = None
        # Defaults are merged into every operation; explicit kwargs win.
        self._call_args = {"add_special_tokens": False}
        self._encode_args = {"add_special_tokens": False}
        self._decode_args = {"skip_special_tokens": True}

    @classmethod
    def from_pretrained(
        cls,
        name: str,
        trust_remote_code: bool = False,
        revision: str = "main",
    ) -> "Tokenizer":
        """
        Build a Tokenizer backed by the given pretrained model.

        Args:
            name: The name or path of the pretrained tokenizer model.
            trust_remote_code: Whether to trust remote code when loading the tokenizer.
            revision: The specific model version to use.
        """
        try:
            wrapper = cls()
            wrapper._tokenizer = AutoTokenizer.from_pretrained(
                name, trust_remote_code=trust_remote_code, revision=revision
            )
        except Exception as e:
            raise InitializationError(e) from e
        return wrapper

    def __call__(self, text, **kwargs) -> "BatchEncoding":
        """
        Tokenize ``text`` using the stored call defaults; kwargs override them.

        Args:
            text: The input text to tokenize.

        Returns:
            A BatchEncoding object containing the tokenized output.
        """
        self._require_loaded()
        merged = {**self._call_args, **kwargs}
        return self._tokenizer(text, **merged)

    def encode(self, text, **kwargs) -> list[int]:
        """
        Encode ``text`` into a list of token IDs using the stored encode
        defaults; kwargs override them.

        Args:
            text: The input text to encode.

        Returns:
            A list of token IDs.
        """
        self._require_loaded()
        merged = {**self._encode_args, **kwargs}
        return self._tokenizer.encode(text, **merged)

    def decode(self, token_ids, **kwargs) -> str:
        """
        Decode ``token_ids`` back into a string using the stored decode
        defaults; kwargs override them.

        Args:
            token_ids: A list of token IDs to decode.

        Returns:
            The decoded string.
        """
        self._require_loaded()
        merged = {**self._decode_args, **kwargs}
        return self._tokenizer.decode(token_ids, **merged)

    @property
    def bos_token_id(self) -> int:
        """
        The beginning-of-sequence (BOS) token ID of the loaded tokenizer.
        """
        self._require_loaded()
        return self._tokenizer.bos_token_id

    def _require_loaded(self) -> None:
        """Raise NotInitializedError unless a tokenizer has been loaded."""
        if self._tokenizer is None:
            raise NotInitializedError("Tokenizer is not initialized.")

    def __repr__(self) -> str:
        """
        Return the repr of the underlying tokenizer.

        Returns:
            The string representation of the tokenizer.
        """
        return repr(self._tokenizer)

    def __str__(self) -> str:
        """
        Return the user-friendly string form of the underlying tokenizer.

        Returns:
            The string representation of the tokenizer.
        """
        return str(self._tokenizer)

bos_token_id property

Return the beginning-of-sequence (BOS) token ID.

__call__(text, **kwargs)

Call the underlying Huggingface tokenizer with default arguments, which can be overridden by kwargs.

Parameters:

Name Type Description Default
text

The input text to tokenize.

required

Returns:

Type Description
BatchEncoding

A BatchEncoding object containing the tokenized output.

Source code in aiperf/common/tokenizer.py
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def __call__(self, text, **kwargs) -> "BatchEncoding":
    """
    Tokenize ``text`` with the stored call defaults; kwargs override them.

    Args:
        text: The input text to tokenize.

    Returns:
        A BatchEncoding object containing the tokenized output.
    """
    if self._tokenizer is None:
        raise NotInitializedError("Tokenizer is not initialized.")
    merged = {**self._call_args, **kwargs}
    return self._tokenizer(text, **merged)

__init__()

Initialize the tokenizer with default values for call, encode, and decode.

Source code in aiperf/common/tokenizer.py
28
29
30
31
32
33
34
35
def __init__(self) -> None:
    """
    Initialize the tokenizer with default values for call, encode, and decode.

    The wrapper starts unloaded (no underlying tokenizer); per-call kwargs
    can override any of the stored defaults.
    """
    self._tokenizer = None
    self._call_args = {"add_special_tokens": False}
    self._encode_args = {"add_special_tokens": False}
    self._decode_args = {"skip_special_tokens": True}

__repr__()

Return a string representation of the underlying tokenizer.

Returns:

Type Description
str

The string representation of the tokenizer.

Source code in aiperf/common/tokenizer.py
119
120
121
122
123
124
125
126
def __repr__(self) -> str:
    """
    Return the repr of the underlying tokenizer.

    Returns:
        The string representation of the tokenizer.
    """
    return repr(self._tokenizer)

__str__()

Return a user-friendly string representation of the underlying tokenizer.

Returns:

Type Description
str

The string representation of the tokenizer.

Source code in aiperf/common/tokenizer.py
128
129
130
131
132
133
134
135
def __str__(self) -> str:
    """
    Return the user-friendly string form of the underlying tokenizer.

    Returns:
        The string representation of the tokenizer.
    """
    return str(self._tokenizer)

decode(token_ids, **kwargs)

Decode a list of token IDs back into a string.

This method calls the underlying Huggingface tokenizer's decode method with default arguments, which can be overridden by kwargs.

Parameters:

Name Type Description Default
token_ids

A list of token IDs to decode.

required

Returns:

Type Description
str

The decoded string.

Source code in aiperf/common/tokenizer.py
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
def decode(self, token_ids, **kwargs) -> str:
    """
    Decode ``token_ids`` back into a string.

    The stored decode defaults are applied first and may be overridden by
    kwargs.

    Args:
        token_ids: A list of token IDs to decode.

    Returns:
        The decoded string.
    """
    if self._tokenizer is None:
        raise NotInitializedError("Tokenizer is not initialized.")
    merged = {**self._decode_args, **kwargs}
    return self._tokenizer.decode(token_ids, **merged)

encode(text, **kwargs)

Encode the input text into a list of token IDs.

This method calls the underlying Huggingface tokenizer's encode method with default arguments, which can be overridden by kwargs.

Parameters:

Name Type Description Default
text

The input text to encode.

required

Returns:

Type Description
list[int]

A list of token IDs.

Source code in aiperf/common/tokenizer.py
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def encode(self, text, **kwargs) -> list[int]:
    """
    Encode the input text into a list of token IDs.

    The stored encode defaults are applied first and may be overridden by
    kwargs.

    Args:
        text: The input text to encode.

    Returns:
        A list of token IDs.
    """
    if self._tokenizer is None:
        raise NotInitializedError("Tokenizer is not initialized.")
    merged = {**self._encode_args, **kwargs}
    return self._tokenizer.encode(text, **merged)

from_pretrained(name, trust_remote_code=False, revision='main') classmethod

Factory to load a tokenizer for the given pretrained model name.

Parameters:

Name Type Description Default
name str

The name or path of the pretrained tokenizer model.

required
trust_remote_code bool

Whether to trust remote code when loading the tokenizer.

False
revision str

The specific model version to use.

'main'
Source code in aiperf/common/tokenizer.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
@classmethod
def from_pretrained(
    cls,
    name: str,
    trust_remote_code: bool = False,
    revision: str = "main",
) -> "Tokenizer":
    """
    Factory that loads a tokenizer for the given pretrained model name.

    Args:
        name: The name or path of the pretrained tokenizer model.
        trust_remote_code: Whether to trust remote code when loading the tokenizer.
        revision: The specific model version to use.
    """
    try:
        wrapper = cls()
        wrapper._tokenizer = AutoTokenizer.from_pretrained(
            name, trust_remote_code=trust_remote_code, revision=revision
        )
    except Exception as e:
        raise InitializationError(e) from e
    return wrapper

aiperf.common.types

MessageTypeT = MessageType | str module-attribute

Alias allowing MessageType to be either an enum value or a custom string for user-defined message types.

aiperf.common.utils

call_all_functions(funcs, *args, **kwargs) async

Call all functions in the list with the given arguments.

Parameters:

Name Type Description Default
funcs

The functions to call.

required
*args

The arguments to pass to the functions.

()
**kwargs

The keyword arguments to pass to the functions.

{}

Raises:

Type Description
AIPerfMultiError

If any of the functions raise an exception.

Source code in aiperf/common/utils.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
async def call_all_functions(funcs: list[Callable], *args, **kwargs) -> None:
    """Call every function in ``funcs`` with the given arguments.

    Coroutine functions are awaited; regular callables are invoked directly.
    Every function is attempted even if earlier ones raise.

    Args:
        funcs: The functions to call.
        *args: The arguments to pass to the functions.
        **kwargs: The keyword arguments to pass to the functions.

    Raises:
        AIPerfMultiError: If any of the functions raise an exception.
    """
    exceptions = []
    for func in funcs:
        try:
            if inspect.iscoroutinefunction(func):
                await func(*args, **kwargs)
            else:
                func(*args, **kwargs)
        except Exception as e:
            # TODO: error handling, logging
            traceback.print_exc()
            exceptions.append(e)

    if exceptions:
        raise AIPerfMultiError("Errors calling functions", exceptions)

call_all_functions_self(self_, funcs, *args, **kwargs) async

Call all functions in the list with the given arguments, passing the given object as the first argument to each.

Parameters:

Name Type Description Default
self_

The object to pass as the first argument to each function.

required
funcs

The functions to call.

required
*args

The arguments to pass to the functions.

()
**kwargs

The keyword arguments to pass to the functions.

{}

Raises:

Type Description
AIPerfMultiError

If any of the functions raise an exception.

Source code in aiperf/common/utils.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
async def call_all_functions_self(
    self_: object, funcs: list[Callable], *args, **kwargs
) -> None:
    """Call every function in ``funcs``, passing ``self_`` as the first argument.

    Coroutine functions are awaited; regular callables are invoked directly.
    Every function is attempted even if earlier ones raise.

    Args:
        self_: The object to pass as the first argument to each function.
        funcs: The functions to call.
        *args: The arguments to pass to the functions.
        **kwargs: The keyword arguments to pass to the functions.

    Raises:
        AIPerfMultiError: If any of the functions raise an exception.
    """
    exceptions = []
    for func in funcs:
        try:
            if inspect.iscoroutinefunction(func):
                await func(self_, *args, **kwargs)
            else:
                func(self_, *args, **kwargs)
        except Exception as e:
            # TODO: error handling, logging
            traceback.print_exc()
            exceptions.append(e)

    if exceptions:
        raise AIPerfMultiError("Errors calling functions", exceptions)

load_json_str(json_str, func=lambda x: x)

Deserializes JSON encoded string into Python object.

Parameters:

Name Type Description Default
- json_str

string JSON encoded string

required
- func

callable A function that takes the deserialized JSON object. This can be used to run validation checks on the object. Defaults to the identity function.

required
Source code in aiperf/common/utils.py
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
def load_json_str(json_str: str, func: Callable = lambda x: x) -> Any:
    """Deserialize a JSON-encoded string into a Python object.

    Args:
        json_str: The JSON-encoded string to deserialize.
        func: A function applied to the deserialized JSON object, e.g. to run
            validation checks on it. Defaults to the identity function.

    Returns:
        The result of ``func`` applied to the deserialized object. This may be
        any JSON-compatible type, not only a dict.

    Raises:
        orjson.JSONDecodeError: If ``json_str`` is not valid JSON.
    """
    try:
        # Note: orjson may not parse JSON the same way as Python's standard json library,
        # notably being stricter on UTF-8 conformance.
        # Refer to https://github.com/ijl/orjson?tab=readme-ov-file#str for details.
        return func(orjson.loads(json_str))
    except orjson.JSONDecodeError:
        # Log only a bounded snippet so huge payloads don't flood the logs.
        snippet = json_str[:200] + ("..." if len(json_str) > 200 else "")
        logger.error("Failed to parse JSON string: '%s'", snippet)
        raise

yield_to_event_loop() async

Yield to the event loop. This forces the current coroutine to yield and allow other coroutines to run, preventing starvation. Use this when you do not want to delay your coroutine via sleep, but still want to allow other coroutines to run if there is a potential for an infinite loop.

Source code in aiperf/common/utils.py
101
102
103
104
105
106
107
async def yield_to_event_loop() -> None:
    """Cooperatively yield control to the asyncio event loop.

    A zero-second sleep suspends the current coroutine so that other
    coroutines get a chance to run. Use it in tight loops where you do not
    want an actual delay but must avoid starving the event loop.
    """
    await asyncio.sleep(0)

aiperf.data_exporter.console_error_exporter

ConsoleErrorExporter

A class that exports error data to the console

Source code in aiperf/data_exporter/console_error_exporter.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
@DataExporterFactory.register(DataExporterType.CONSOLE_ERROR)
class ConsoleErrorExporter:
    """Writes a summary table of request errors to the console."""

    # (header, justify) pairs describing the error-summary table columns.
    _COLUMNS = (
        ("Code", "right"),
        ("Type", "right"),
        ("Message", "left"),
        ("Count", "right"),
    )

    def __init__(self, exporter_config: ExporterConfig):
        self._results = exporter_config.results

    async def export(self, width: int | None = None) -> None:
        """Print the error table (if any errors occurred) and a cancellation notice."""
        console = Console()

        if len(self._results.errors_by_type) > 0:
            table = Table(title=self._get_title(), width=width)
            for header, justify in self._COLUMNS:
                table.add_column(header, justify=justify, style="yellow")
            self._construct_table(table, self._results.errors_by_type)

            console.print("\n")
            console.print(table)

        if self._results.was_cancelled:
            console.print("[red][bold]Profile run was cancelled early[/bold][/red]")

        console.file.flush()

    def _construct_table(
        self, table: Table, errors_by_type: list[ErrorDetailsCount]
    ) -> None:
        """Append one row per distinct error type."""
        for entry in errors_by_type:
            table.add_row(*self._format_row(entry))

    def _format_row(self, error_details_count: ErrorDetailsCount) -> list[str]:
        """Render one error entry as the four table cells."""
        details = error_details_count.error_details
        code_cell = str(details.code) if details.code else "[dim]N/A[/dim]"
        type_cell = str(details.type) if details.type else "[dim]N/A[/dim]"
        return [
            code_cell,
            type_cell,
            str(details.message),
            f"{error_details_count.count:,}",
        ]

    def _get_title(self) -> str:
        """Table title, with rich console markup."""
        return "[bold][red]NVIDIA AIPerf | Error Summary[/red][/bold]"

aiperf.data_exporter.console_exporter

ConsoleExporter

A class that exports data to the console

Source code in aiperf/data_exporter/console_exporter.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
@DataExporterFactory.register(DataExporterType.CONSOLE)
class ConsoleExporter:
    """Writes the metric-results table to the console."""

    STAT_COLUMN_KEYS = ["avg", "min", "max", "p99", "p90", "p75", "std", "count"]

    def __init__(self, exporter_config: ExporterConfig) -> None:
        self._results = exporter_config.results
        self._endpoint_type = exporter_config.input_config.endpoint.type
        self._streaming = exporter_config.input_config.endpoint.streaming

    async def export(self, width: int | None = None) -> None:
        """Build the metrics table and print it to the console."""
        table = Table(title=self._get_title(), width=width)
        table.add_column("Metric", justify="right", style="cyan")
        for key in self.STAT_COLUMN_KEYS:
            table.add_column(key, justify="right", style="green")
        self._construct_table(table, self._results.records)

        console = Console()
        console.print("\n")
        console.print(table)
        console.file.flush()

    def _construct_table(self, table: Table, records: list[MetricResult]) -> None:
        """Add a row for every record that applies to this run."""
        for record in records:
            if not self._should_skip(record):
                table.add_row(*self._format_row(record))

    def _should_skip(self, record: MetricResult) -> bool:
        """Streaming-only metrics are hidden unless streaming or embeddings."""
        if self._endpoint_type == "embeddings":
            return False
        return record.streaming_only and not self._streaming

    def _format_row(self, record: MetricResult) -> list[str]:
        """Format the metric header plus each stat value as table cells."""
        cells = [f"{record.header} ({record.unit})"]
        for stat in self.STAT_COLUMN_KEYS:
            value = getattr(record, stat, None)
            # Float check first so bool/int formatting never applies to floats.
            if isinstance(value, float):
                cells.append(f"{value:,.2f}")
            elif isinstance(value, int):
                cells.append(f"{value:,}")
            else:
                cells.append("[dim]N/A[/dim]")
        return cells

    def _get_title(self) -> str:
        """Title line keyed off the endpoint type."""
        titles_by_endpoint = {
            "embeddings": "Embeddings Metrics",
            "rankings": "Rankings Metrics",
            "image_retrieval": "Image Retrieval Metrics",
            "multimodal": "Multi-Modal Metrics",
        }
        metric_title = titles_by_endpoint.get(self._endpoint_type, "LLM Metrics")
        return f"NVIDIA AIPerf | {metric_title}"

aiperf.data_exporter.exporter_config

aiperf.data_exporter.exporter_manager

ExporterManager

ExporterManager is responsible for exporting records using all registered data exporters.

Source code in aiperf/data_exporter/exporter_manager.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class ExporterManager:
    """
    Runs every registered data exporter over a set of profile results.
    """

    def __init__(self, results: ProfileResultsMessage, input_config: UserConfig):
        self._results = results
        self._input_config = input_config
        self._exporter_classes = DataExporterFactory.get_all_classes()

    async def export_all(self) -> None:
        """Instantiate each registered exporter and run all exports concurrently."""
        tasks: list[asyncio.Task] = [
            asyncio.create_task(
                exporter_cls(
                    ExporterConfig(
                        results=self._results,
                        input_config=self._input_config,
                    )
                ).export()
            )
            for exporter_cls in self._exporter_classes
        ]
        await asyncio.gather(*tasks)

aiperf.data_exporter.json_exporter

JsonExportData

Bases: BaseModel

Data to be exported to a JSON file.

Source code in aiperf/data_exporter/json_exporter.py
18
19
20
21
22
23
24
25
26
class JsonExportData(BaseModel):
    """Data to be exported to a JSON file."""

    # Full user configuration the run was executed with.
    input_config: UserConfig | None = None
    # Metric results keyed by metric tag.
    records: dict[str, MetricResult] | None = None
    # True if the profile run was cancelled before completion.
    was_cancelled: bool | None = None
    # Aggregated error counts, one entry per distinct error type.
    errors_by_type: list[ErrorDetailsCount] | None = None
    # Wall-clock start/end of the run; None if unknown.
    start_time: datetime | None = None
    end_time: datetime | None = None

JsonExporter

A class to export records to a JSON file.

Source code in aiperf/data_exporter/json_exporter.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
@DataExporterFactory.register(DataExporterType.JSON)
class JsonExporter:
    """
    Exports profile results to a JSON artifact file.
    """

    def __init__(self, exporter_config: ExporterConfig) -> None:
        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.debug("Initializing JsonExporter with config: %s", exporter_config)
        self._results = exporter_config.results
        self._output_directory = exporter_config.input_config.output.artifact_directory
        self._input_config = exporter_config.input_config

    async def export(self) -> None:
        """Write the results to profile_export_aiperf.json in the artifact directory."""
        self._output_directory.mkdir(parents=True, exist_ok=True)
        filename = self._output_directory / "profile_export_aiperf.json"

        export_data = JsonExportData(
            input_config=self._input_config,
            records={record.tag: record for record in self._results.records},
            was_cancelled=self._results.was_cancelled,
            errors_by_type=self._results.errors_by_type,
            start_time=self._ns_to_datetime(self._results.start_ns),
            end_time=self._ns_to_datetime(self._results.end_ns),
        )

        self.logger.debug("Exporting data to JSON file: %s", export_data)
        payload = export_data.model_dump_json(indent=2, exclude_unset=True)
        async with aiofiles.open(filename, "w") as f:
            await f.write(payload)

    @staticmethod
    def _ns_to_datetime(timestamp_ns: int | None) -> datetime | None:
        """Convert a nanosecond timestamp to a datetime, or None when falsy."""
        return (
            datetime.fromtimestamp(timestamp_ns / NANOS_PER_SECOND)
            if timestamp_ns
            else None
        )

aiperf.progress.progress_models

BenchmarkSuiteCompletionTrigger

Bases: CaseInsensitiveStrEnum

Determines how the suite completion is determined in order to know how to track the progress.

Source code in aiperf/progress/progress_models.py
101
102
103
104
105
106
107
108
class BenchmarkSuiteCompletionTrigger(CaseInsensitiveStrEnum):
    """Determines how the suite completion is determined in order to know how to track the progress."""

    UNKNOWN = "unknown"  # Trigger has not been determined.
    COMPLETED_SWEEPS = "completed_sweeps"  # Suite completes when all sweeps finish.
    COMPLETED_PROFILES = "completed_profiles"  # Suite completes when all profiles finish.
    STABILIZATION_BASED = "stabilization_based"  # Suite completes when metrics stabilize.
    CUSTOM = "custom"  # TBD

BenchmarkSuiteProgress

Bases: BaseModel, ABC

State of the suite progress.

Source code in aiperf/progress/progress_models.py
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
class BenchmarkSuiteProgress(BaseModel, ABC):
    """State of the suite progress."""

    suite_type: BenchmarkSuiteType = Field(
        default=BenchmarkSuiteType.SINGLE_PROFILE,
        description="The type of suite. Default is SINGLE_PROFILE.",
    )
    suite_completion_trigger: BenchmarkSuiteCompletionTrigger = Field(
        default=BenchmarkSuiteCompletionTrigger.COMPLETED_PROFILES,
        description="The trigger of suite completion",
    )
    start_time_ns: int | None = Field(
        default=None,
        description="The overall start time of the suite in nanoseconds. If it has not been started, this will be None.",
    )
    end_time_ns: int | None = Field(
        default=None,
        description="The overall end time of the suite in nanoseconds. If it has not been completed, this will be None.",
    )
    was_cancelled: bool = Field(
        default=False,
        description="Whether the suite was cancelled early",
    )

    @property
    def current_sweep(self) -> SweepProgress | None:
        """The in-progress sweep, or None for non-sweep suites or when no sweep is active."""
        # NOTE(review): current_sweep_idx is NOT bounds-checked against
        # len(self.sweeps), unlike current_profile_idx below — confirm an
        # out-of-range index cannot occur here.
        if not isinstance(self, SweepSuiteProgress) or self.current_sweep_idx is None:
            return None
        return self.sweeps[self.current_sweep_idx]

    @property
    def current_profile(self) -> ProfileProgress | None:
        """The in-progress profile, resolved for either a profile suite or a sweep suite."""
        if isinstance(self, ProfileSuiteProgress):
            # Unset or out-of-range index means no profile is active.
            if self.current_profile_idx is None or self.current_profile_idx >= len(
                self.profiles
            ):
                return None
            return self.profiles[self.current_profile_idx]

        elif isinstance(self, SweepSuiteProgress):
            # Delegate to the active sweep's own current profile.
            if self.current_sweep is None:
                return None
            return self.current_sweep.current_profile

        return None

    @abstractmethod
    def next_profile(self) -> ProfileProgress | None: ...

BenchmarkSuiteType

Bases: CaseInsensitiveStrEnum

Determines the type of suite to know how to track the progress.

Source code in aiperf/progress/progress_models.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
class BenchmarkSuiteType(CaseInsensitiveStrEnum):
    """Determines the type of suite to know how to track the progress."""

    SINGLE_PROFILE = "single_profile"
    """An suite with a single profile run."""

    MULTI_PROFILE = "multi_profile"
    """An suite with multiple profile runs. As opposed to a sweep, more than one parameter can be varied. TBD"""

    SINGLE_SWEEP = "single_sweep"
    """An suite with a single sweep over one or more varying parameters. TBD"""

    MULTI_SWEEP = "multi_sweep"
    """An suite with multiple sweep runs over multiple varying parameters. TBD"""

    CUSTOM = "custom"
    """User defined suite type. TBD"""

CUSTOM = 'custom' class-attribute instance-attribute

User defined suite type. TBD

MULTI_PROFILE = 'multi_profile' class-attribute instance-attribute

A suite with multiple profile runs. As opposed to a sweep, more than one parameter can be varied. TBD

MULTI_SWEEP = 'multi_sweep' class-attribute instance-attribute

A suite with multiple sweep runs over multiple varying parameters. TBD

SINGLE_PROFILE = 'single_profile' class-attribute instance-attribute

A suite with a single profile run.

SINGLE_SWEEP = 'single_sweep' class-attribute instance-attribute

A suite with a single sweep over one or more varying parameters. TBD

ProfileCompletionTrigger

Bases: CaseInsensitiveStrEnum

Determines how the profile completion is determined in order to know how to track the progress.

Source code in aiperf/progress/progress_models.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class ProfileCompletionTrigger(CaseInsensitiveStrEnum):
    """Determines how the profile completion is determined in order to know how to track the progress."""

    REQUEST_COUNT = "request_count"
    """The profile will run for a fixed number of requests."""

    TIME_BASED = "time_based"
    """The profile will run for a fixed amount of time."""

    STABILIZATION_BASED = "stabilization_based"
    """The profile will run until the metrics stabilize. TDB"""

    GOODPUT_THRESHOLD = "goodput_threshold"
    """The profile will run until the goodput threshold is met. TDB"""

    CUSTOM = "custom"
    """User defined trigger. TBD"""

CUSTOM = 'custom' class-attribute instance-attribute

User defined trigger. TBD

GOODPUT_THRESHOLD = 'goodput_threshold' class-attribute instance-attribute

The profile will run until the goodput threshold is met. TBD

REQUEST_COUNT = 'request_count' class-attribute instance-attribute

The profile will run for a fixed number of requests.

STABILIZATION_BASED = 'stabilization_based' class-attribute instance-attribute

The profile will run until the metrics stabilize. TBD

TIME_BASED = 'time_based' class-attribute instance-attribute

The profile will run for a fixed amount of time.

ProfileProgress

Bases: BaseModel

State of the profile progress.

Source code in aiperf/progress/progress_models.py
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
class ProfileProgress(BaseModel):
    """State of the profile progress."""

    profile_id: str = Field(..., description="The ID of the profile")

    profile_completion_trigger: ProfileCompletionTrigger = Field(
        default=ProfileCompletionTrigger.REQUEST_COUNT,
        description="The trigger of profile completion",
    )

    # --- Timing (nanosecond timestamps; None until the event has occurred) ---
    start_time_ns: int | None = Field(
        default=None,
        description="The start time of the profile run in nanoseconds. If it has not been started, this will be None.",
    )
    end_time_ns: int | None = Field(
        default=None,
        description="The end time of the profile run in nanoseconds. If it has not been completed, this will be None.",
    )

    # --- Request counters ---
    total_expected_requests: int | None = Field(
        default=None,
        description="The total number of inference requests to be made. This will be None if the profile completion trigger is not request-based.",
    )
    requests_completed: int = Field(
        default=0,
        description="The number of inference requests completed during the profile run",
    )
    request_errors: int = Field(
        default=0,
        description="The total number of request errors encountered during the profile run",
    )
    successful_requests: int = Field(
        default=0,
        description="The total number of successful requests completed during the profile run",
    )
    requests_processed: int = Field(
        default=0,
        description="The total number of requests processed by the records manager "
        "during the profile run. This can be less than the requests_completed if "
        "the records manager processing requests is slower than the inference requests "
        "are being made.",
    )

    # --- Throughput rates (None until computed) ---
    requests_per_second: float | None = Field(
        default=None,
        description="The number of requests completed per second during the profile run",
    )
    processed_per_second: float | None = Field(
        default=None,
        description="The number of requests processed by the records manager per second during the profile run",
    )

    # --- Per-worker breakdowns, keyed by worker service_id ---
    worker_completed: dict[str, int] = Field(
        default_factory=dict,
        description="Per-worker request completion counts, keyed by worker service_id during the profile run",
    )
    worker_errors: dict[str, int] = Field(
        default_factory=dict,
        description="Per-worker error counts, keyed by worker service_id during the profile run",
    )

    # --- Run status and time estimates (seconds) ---
    was_cancelled: bool = Field(
        default=False,
        description="Whether the profile run was cancelled early",
    )
    elapsed_time: float = Field(
        default=0,
        description="The elapsed time of the profile run in seconds",
    )
    eta: float | None = Field(
        default=None,
        description="The estimated time remaining for the profile run in seconds",
    )
    processing_eta: float | None = Field(
        default=None,
        description="The estimated time remaining for processing the records in seconds",
    )

    # --- Results ---
    records: SerializeAsAny[list[MetricResult]] = Field(
        default_factory=list, description="The records of the profile results"
    )
    errors_by_type: list[ErrorDetailsCount] = Field(
        default_factory=list,
        description="A list of the unique error details and their counts",
    )
    is_complete: bool = Field(
        default=False,
        description="Whether the profile run is complete",
    )

ProfileSuiteProgress

Bases: BenchmarkSuiteProgress

State of a profile based suite with 1 or more profile runs.

Source code in aiperf/progress/progress_models.py
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
class ProfileSuiteProgress(BenchmarkSuiteProgress):
    """State of a profile based suite with 1 or more profile runs."""

    profiles: list[ProfileProgress] = Field(
        default_factory=list, description="The state of the profiles in the suite"
    )
    total_profiles: int = Field(default=0, description="The total number of profiles")
    completed_profiles: int = Field(
        default=0, description="The number of completed profiles"
    )
    current_profile_idx: int | None = Field(
        default=None,
        description="The index of the current profile run. If it has not been started, this will be None.",
    )

    def next_profile(self) -> ProfileProgress | None:
        """Advance to the next profile and return it.

        Starts the index at 0 on the first call, then increments it on each
        subsequent call. Returns None once all profiles are exhausted (the
        index is still advanced, ending up at len(self.profiles)).
        """
        idx = 0 if self.current_profile_idx is None else self.current_profile_idx + 1
        self.current_profile_idx = idx
        return self.profiles[idx] if idx < len(self.profiles) else None

SweepCompletionTrigger

Bases: CaseInsensitiveStrEnum

Determines how the sweep completion is determined in order to know how to track the progress.

Source code in aiperf/progress/progress_models.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class SweepCompletionTrigger(CaseInsensitiveStrEnum):
    """Determines how the sweep completion is determined in order to know how to track the progress."""

    COMPLETED_PROFILES = "completed_profiles"
    """The sweep will run until all profiles are completed."""

    STABILIZATION_BASED = "stabilization_based"
    """The sweep will run until the metrics stabilize. TDB"""

    GOODPUT_THRESHOLD = "goodput_threshold"
    """The sweep will run until the goodput threshold is met. TDB"""

    CUSTOM = "custom"
    """User defined trigger. TBD"""

COMPLETED_PROFILES = 'completed_profiles' class-attribute instance-attribute

The sweep will run until all profiles are completed.

CUSTOM = 'custom' class-attribute instance-attribute

User defined trigger. TBD

GOODPUT_THRESHOLD = 'goodput_threshold' class-attribute instance-attribute

The sweep will run until the goodput threshold is met. TBD

STABILIZATION_BASED = 'stabilization_based' class-attribute instance-attribute

The sweep will run until the metrics stabilize. TBD

SweepMultiParamOrder

Bases: CaseInsensitiveStrEnum

Determines the order in which the sweep parameters are tested for a multi-parameter sweep. This is only applicable for multi-parameter sweeps.

Source code in aiperf/progress/progress_models.py
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
class SweepMultiParamOrder(CaseInsensitiveStrEnum):
    """Determines the order in which the sweep parameters are tested for a multi-parameter sweep.
    This is only applicable for multi-parameter sweeps."""

    DEPTH_FIRST = "depth_first"
    """The parameters are tested in depth-first order."""

    BREADTH_FIRST = "breadth_first"
    """The parameters are tested in breadth-first order."""

    # NOTE(review): RANDOM and CUSTOM are marked TBD — placeholders, not yet
    # implemented as far as this file shows.
    RANDOM = "random"
    """The parameters are tested in random order. TBD"""

    CUSTOM = "custom"
    """User defined order. TBD"""

BREADTH_FIRST = 'breadth_first' class-attribute instance-attribute

The parameters are tested in breadth-first order.

CUSTOM = 'custom' class-attribute instance-attribute

User defined order. TBD

DEPTH_FIRST = 'depth_first' class-attribute instance-attribute

The parameters are tested in depth-first order.

RANDOM = 'random' class-attribute instance-attribute

The parameters are tested in random order. TBD

SweepParamOrder

Bases: CaseInsensitiveStrEnum

Determines the order in which the sweep parameters are tested.

Source code in aiperf/progress/progress_models.py
68
69
70
71
72
73
74
75
76
77
78
79
80
81
class SweepParamOrder(CaseInsensitiveStrEnum):
    """Determines the order in which the sweep parameters are tested."""

    ASCENDING = "ascending"
    """The parameters are tested in ascending order."""

    DESCENDING = "descending"
    """The parameters are tested in descending order."""

    # NOTE(review): RANDOM and CUSTOM are marked TBD — placeholders, not yet
    # implemented as far as this file shows.
    RANDOM = "random"
    """The parameters are tested in random order. TBD"""

    CUSTOM = "custom"
    """User defined order. TBD"""

ASCENDING = 'ascending' class-attribute instance-attribute

The parameters are tested in ascending order.

CUSTOM = 'custom' class-attribute instance-attribute

User defined order. TBD

DESCENDING = 'descending' class-attribute instance-attribute

The parameters are tested in descending order.

RANDOM = 'random' class-attribute instance-attribute

The parameters are tested in random order. TBD

SweepParamType

Bases: CaseInsensitiveStrEnum

Determines the type of sweep parameter.

Source code in aiperf/progress/progress_models.py
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
class SweepParamType(CaseInsensitiveStrEnum):
    """Determines the type of sweep parameter."""

    INT = "int"
    """The parameter is an integer."""

    FLOAT = "float"
    """The parameter is a float."""

    STRING = "string"
    """The parameter is a string."""

    BOOLEAN = "boolean"
    """The parameter is a boolean."""

    # NOTE(review): CUSTOM is marked TBD — a placeholder, not yet implemented
    # as far as this file shows.
    CUSTOM = "custom"
    """User defined parameter type. TBD"""

BOOLEAN = 'boolean' class-attribute instance-attribute

The parameter is a boolean.

CUSTOM = 'custom' class-attribute instance-attribute

User defined parameter type. TBD

FLOAT = 'float' class-attribute instance-attribute

The parameter is a float.

INT = 'int' class-attribute instance-attribute

The parameter is an integer.

STRING = 'string' class-attribute instance-attribute

The parameter is a string.

SweepProgress

Bases: BaseModel

State of the sweep progress.

Source code in aiperf/progress/progress_models.py
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
class SweepProgress(BaseModel):
    """State of the sweep progress."""

    sweep_id: str = Field(..., description="The ID of the current sweep")
    sweep_completion_trigger: SweepCompletionTrigger = Field(
        default=SweepCompletionTrigger.COMPLETED_PROFILES,
        description="The trigger of sweep completion",
    )
    profiles: list[ProfileProgress] = Field(
        default_factory=list, description="The state of the profiles in the sweep"
    )
    current_profile_idx: int | None = Field(
        default=None,
        description="The index of the current profile. If it has not been started, this will be None.",
    )
    completed_profiles: int = Field(
        default=0, description="The number of completed profiles in the sweep"
    )
    start_time_ns: int | None = Field(
        default=None,
        description="The start time of the sweep in nanoseconds. If it has not been started, this will be None.",
    )
    end_time_ns: int | None = Field(
        default=None,
        description="The end time of the sweep in nanoseconds. If it has not been completed, this will be None.",
    )
    was_cancelled: bool = Field(
        default=False,
        description="Whether the sweep was cancelled early",
    )

    @property
    def current_profile(self) -> ProfileProgress | None:
        """The profile currently being run, or None if there isn't one.

        Returns None both before the sweep starts and after it is exhausted.
        next_profile() leaves current_profile_idx at len(self.profiles) once
        the last profile has been returned, so an unguarded index here would
        raise IndexError; report "no current profile" instead.
        """
        if self.current_profile_idx is None:
            return None
        if self.current_profile_idx >= len(self.profiles):
            return None
        return self.profiles[self.current_profile_idx]

    def next_profile(self) -> ProfileProgress | None:
        """Advance to the next profile and return it.

        Starts the index at 0 on the first call, then increments on each
        subsequent call. Returns None once all profiles are exhausted.
        """
        if self.current_profile_idx is None:
            self.current_profile_idx = 0
        else:
            self.current_profile_idx += 1

        if self.current_profile_idx >= len(self.profiles):
            return None

        return self.profiles[self.current_profile_idx]

SweepSuiteProgress

Bases: BenchmarkSuiteProgress

State of a sweep based suite with 1 or more sweep runs.

Source code in aiperf/progress/progress_models.py
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
class SweepSuiteProgress(BenchmarkSuiteProgress):
    """State of a sweep based suite with 1 or more sweep runs."""

    sweeps: list[SweepProgress] = Field(
        default_factory=list, description="The state of the sweeps in the suite"
    )
    total_sweeps: int = Field(default=0, description="The total number of sweeps")
    completed_sweeps: int = Field(
        default=0, description="The number of completed sweeps"
    )
    current_sweep_idx: int | None = Field(
        default=None,
        description="The index of the current sweep. If it has not been started, this will be None.",
    )

    def next_profile(self) -> ProfileProgress | None:
        """Get the next profile to run, advancing across sweeps as needed.

        Returns:
            The next profile to run, or None if there are no more profiles to run.
        """
        if self.current_sweep is None:
            # No sweep has been started yet: start the first sweep.
            next_sweep = self.next_sweep()
            if next_sweep is None:
                return None
            return next_sweep.next_profile()

        # Ask the current sweep for its next profile. If the sweep exists but
        # has not started any profile yet, SweepProgress.next_profile() starts
        # it at index 0. (The previous code advanced to the *next* sweep in
        # that case, silently skipping every profile of the current sweep.)
        next_profile = self.current_sweep.next_profile()
        if next_profile is not None:
            return next_profile

        # Current sweep is exhausted: move on to the next sweep.
        next_sweep = self.next_sweep()
        if next_sweep is None:
            return None
        return next_sweep.next_profile()

    def next_sweep(self) -> SweepProgress | None:
        """Get the next sweep to run.

        Returns:
            The next sweep to run, or None if there are no more sweeps to run.
        """
        if self.current_sweep_idx is None:
            # Guard against an empty suite instead of raising IndexError.
            if not self.sweeps:
                return None
            self.current_sweep_idx = 0
            return self.sweeps[0]
        if self.current_sweep_idx >= len(self.sweeps) - 1:
            return None
        self.current_sweep_idx += 1
        return self.sweeps[self.current_sweep_idx]

next_profile()

Get the next profile to run.

Returns:

Type Description
ProfileProgress | None

The next profile to run, or None if there are no more profiles to run.

Source code in aiperf/progress/progress_models.py
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
def next_profile(self) -> ProfileProgress | None:
    """Get the next profile to run.

    Returns:
        The next profile to run, or None if there are no more profiles to run.
    """
    # NOTE(review): when a sweep is already selected but none of its profiles
    # have started (current_profile_idx is None), this branch advances to the
    # *next* sweep, skipping the current sweep's profiles — confirm intended.
    if self.current_sweep is None or self.current_sweep.current_profile_idx is None:
        next_sweep = self.next_sweep()
        if next_sweep is None:
            return None
        return next_sweep.next_profile()

    # Try to get the next profile in the current sweep
    next_profile = self.current_sweep.next_profile()
    if next_profile is not None:
        return next_profile

    # If no more profiles in current sweep, move to next sweep
    next_sweep = self.next_sweep()
    if next_sweep is None:
        return None
    return next_sweep.next_profile()

next_sweep()

Get the next sweep to run.

Returns:

Type Description
SweepProgress | None

The next sweep to run, or None if there are no more sweeps to run.

Source code in aiperf/progress/progress_models.py
381
382
383
384
385
386
387
388
389
390
391
392
393
def next_sweep(self) -> SweepProgress | None:
    """Get the next sweep to run.

    Returns:
        The next sweep to run, or None if there are no more sweeps to run.
    """
    # First call: start at the first sweep.
    # NOTE(review): assumes self.sweeps is non-empty here, otherwise this
    # raises IndexError — confirm callers guarantee at least one sweep.
    if self.current_sweep_idx is None:
        self.current_sweep_idx = 0
        return self.sweeps[0]
    # Already on (or past) the last sweep: nothing left to run.
    if self.current_sweep_idx >= len(self.sweeps) - 1:
        return None
    self.current_sweep_idx += 1
    return self.sweeps[self.current_sweep_idx]

aiperf.services.dataset.composer.base

BaseDatasetComposer

Bases: ABC

Source code in aiperf/services/dataset/composer/base.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
class BaseDatasetComposer(ABC):
    """Common base for dataset composers.

    Holds the input configuration and one synthetic generator per modality
    (prompt, image, audio) for concrete composers to use.
    """

    def __init__(self, config: InputConfig, tokenizer: Tokenizer):
        """Store the configuration and build the per-modality generators.

        Args:
            config: Input configuration with prompt/image/audio sub-configs.
            tokenizer: Tokenizer forwarded to the prompt generator.
        """
        self.config = config
        self.logger = logging.getLogger(self.__class__.__name__)

        # Each generator reads only its own sub-config.
        self.prompt_generator = PromptGenerator(config.prompt, tokenizer)
        self.image_generator = ImageGenerator(config.image)
        self.audio_generator = AudioGenerator(config.audio)

    @abstractmethod
    def create_dataset(self) -> list[Conversation]:
        """
        Create a set of conversation objects from the given configuration.

        Returns:
            list[Conversation]: A list of conversation objects.
        """
        ...

    @property
    def prefix_prompt_enabled(self) -> bool:
        # Enabled when the configured prefix prompt length is positive.
        return self.config.prompt.prefix_prompt.length > 0

create_dataset() abstractmethod

Create a set of conversation objects from the given configuration.

Returns:

Type Description
list[Conversation]

list[Conversation]: A list of conversation objects.

Source code in aiperf/services/dataset/composer/base.py
26
27
28
29
30
31
32
33
34
@abstractmethod
def create_dataset(self) -> list[Conversation]:
    """
    Create a set of conversation objects from the given configuration.

    Implemented by concrete composers (e.g. synthetic or custom file-based).

    Returns:
        list[Conversation]: A list of conversation objects.
    """
    ...

aiperf.services.dataset.composer.custom

CustomDatasetComposer

Bases: BaseDatasetComposer

Source code in aiperf/services/dataset/composer/custom.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
@ComposerFactory.register(ComposerType.CUSTOM)
class CustomDatasetComposer(BaseDatasetComposer):
    """Composes conversations from a user-provided (custom) dataset file.

    Loading and conversion are delegated to a loader selected via
    CustomDatasetFactory based on the configured custom dataset type.
    """

    # NOTE: the previous __init__ override only forwarded its arguments to
    # super() with an identical signature; inheriting
    # BaseDatasetComposer.__init__ directly is equivalent, so it was removed.

    def create_dataset(self) -> list[Conversation]:
        """Create conversations from a file or directory.

        Returns:
            list[Conversation]: A list of conversation objects.
        """
        # TODO: (future) for K8s, we need to transfer file data from SC (across node)
        # Validate the configured dataset file before building the loader.
        utils.check_file_exists(self.config.file)

        self._create_loader_instance(self.config.custom_dataset_type)
        dataset = self.loader.load_dataset()
        conversations = self.loader.convert_to_conversations(dataset)
        return conversations

    def _create_loader_instance(self, dataset_type: CustomDatasetType) -> None:
        """Initializes the dataset loader (stored on self.loader).

        Args:
            dataset_type: The type of custom dataset to create.
        """
        kwargs = {"filename": self.config.file}
        if dataset_type == CustomDatasetType.MOONCAKE_TRACE:
            # Mooncake trace loading additionally needs the prompt generator —
            # presumably to synthesize prompt text for the trace; confirm.
            kwargs["prompt_generator"] = self.prompt_generator

        self.loader = CustomDatasetFactory.create_instance(dataset_type, **kwargs)

create_dataset()

Create conversations from a file or directory.

Returns:

Type Description
list[Conversation]

list[Conversation]: A list of conversation objects.

Source code in aiperf/services/dataset/composer/custom.py
18
19
20
21
22
23
24
25
26
27
28
29
30
def create_dataset(self) -> list[Conversation]:
    """Create conversations from a file or directory.

    Returns:
        list[Conversation]: A list of conversation objects.
    """
    # TODO: (future) for K8s, we need to transfer file data from SC (across node)
    # Validate the configured dataset file before building the loader.
    utils.check_file_exists(self.config.file)

    # Build the loader for the configured type, then load and convert.
    self._create_loader_instance(self.config.custom_dataset_type)
    dataset = self.loader.load_dataset()
    conversations = self.loader.convert_to_conversations(dataset)
    return conversations

aiperf.services.dataset.composer.synthetic

SyntheticDatasetComposer

Bases: BaseDatasetComposer

Source code in aiperf/services/dataset/composer/synthetic.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
@ComposerFactory.register(ComposerType.SYNTHETIC)
class SyntheticDatasetComposer(BaseDatasetComposer):
    """Builds fully synthetic conversation datasets.

    Each conversation gets a sampled number of turns; each turn carries
    whichever synthetic payloads (text, image, audio) are enabled by the
    configured means.
    """

    def __init__(self, config: InputConfig, tokenizer: Tokenizer):
        super().__init__(config, tokenizer)

        # At least one modality must be enabled, otherwise the composer
        # would only ever produce empty turns.
        if not self.include_prompt and not self.include_image and not self.include_audio:
            raise ValueError(
                "All synthetic data are disabled. "
                "Please enable at least one of prompt, image, or audio by "
                "setting the mean to a positive value."
            )

    def create_dataset(self) -> list[Conversation]:
        """Create a synthetic conversation dataset from the given configuration.

        It generates a set of conversations with a varying number of turns,
        where each turn contains synthetic text, image, and audio payloads.

        Returns:
            list[Conversation]: A list of conversation objects.
        """
        return [
            self._build_conversation() for _ in range((self.config.conversation.num))
        ]

    def _build_conversation(self) -> Conversation:
        """Build a single conversation with a sampled number of turns."""
        convo = Conversation(session_id=str(uuid.uuid4()))

        turn_cfg = self.config.conversation.turn
        num_turns = utils.sample_positive_normal_integer(
            turn_cfg.mean, turn_cfg.stddev
        )
        self.logger.debug("Creating conversation with %d turns", num_turns)

        for idx in range(num_turns):
            convo.turns.append(self._create_turn(is_first=(idx == 0)))
        return convo

    def _create_turn(self, is_first: bool) -> Turn:
        """Build a turn populated with the enabled synthetic payloads.

        Args:
            is_first: Whether the turn is the first turn in the conversation
                (first turns get no delay and may get a prefix prompt).

        Returns:
            Turn: A dataset representation of a single turn.
        """
        turn = Turn()

        if self.include_prompt:
            turn.texts.append(self._generate_text_payloads(is_first))
        if self.include_image:
            turn.images.append(self._generate_image_payloads())
        if self.include_audio:
            turn.audios.append(self._generate_audio_payloads())

        # Every turn after the first gets a randomized inter-turn delay.
        if not is_first:
            delay_cfg = self.config.conversation.turn.delay
            turn.delay = utils.sample_positive_normal_integer(
                delay_cfg.mean, delay_cfg.stddev
            )

        if not turn.texts and not turn.images and not turn.audios:
            self.logger.warning(
                "There were no synthetic payloads generated. "
                "Please enable at least one of prompt, image, or audio by "
                "setting the mean to a positive value."
            )

        return turn

    def _generate_text_payloads(self, is_first: bool) -> Text:
        """Generate a batch of synthetic prompts as one text payload.

        First turns may be prefixed with a random prefix prompt when the
        prefix prompt feature is enabled.

        Args:
            is_first: Whether the turn is the first turn in the conversation.

        Returns:
            Text: A text payload object.
        """
        payload = Text(name="text")
        prompt_cfg = self.config.prompt
        for _ in range(prompt_cfg.batch_size):
            body = self.prompt_generator.generate(
                mean=prompt_cfg.input_tokens.mean,
                stddev=prompt_cfg.input_tokens.stddev,
            )

            if self.prefix_prompt_enabled and is_first:
                # TODO: Rename
                body = f"{self.prompt_generator.get_random_prefix_prompt()} {body}"

            payload.contents.append(body)
        return payload

    def _generate_image_payloads(self) -> Image:
        """Generate a batch of synthetic images as one image payload.

        Returns:
            Image: An image payload object.
        """
        payload = Image(name="image_url")
        for _ in range(self.config.image.batch_size):
            payload.contents.append(self.image_generator.generate())
        return payload

    def _generate_audio_payloads(self) -> Audio:
        """Generate a batch of synthetic audio clips as one audio payload.

        Returns:
            Audio: An audio payload object.
        """
        payload = Audio(name="input_audio")
        for _ in range(self.config.audio.batch_size):
            payload.contents.append(self.audio_generator.generate())
        return payload

    @property
    def include_prompt(self) -> bool:
        # Text generation is on when a positive mean token count is set.
        return self.config.prompt.input_tokens.mean > 0

    @property
    def include_image(self) -> bool:
        # Image generation needs both dimensions to have positive means.
        return self.config.image.width.mean > 0 and self.config.image.height.mean > 0

    @property
    def include_audio(self) -> bool:
        # Audio generation is on when a positive mean length is set.
        return self.config.audio.length.mean > 0

create_dataset()

Create a synthetic conversation dataset from the given configuration.

It generates a set of conversations with a varying number of turns, where each turn contains synthetic text, image, and audio payloads.

Returns:

Type Description
list[Conversation]

list[Conversation]: A list of conversation objects.

Source code in aiperf/services/dataset/composer/synthetic.py
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def create_dataset(self) -> list[Conversation]:
    """Create a synthetic conversation dataset from the given configuration.

    It generates a set of conversations with a varying number of turns,
    where each turn contains synthetic text, image, and audio payloads.

    Returns:
        list[Conversation]: A list of conversation objects.
    """
    conversations = []
    for _ in range(self.config.conversation.num):
        conversation = Conversation(session_id=str(uuid.uuid4()))

        # Sample this conversation's turn count (positive integer).
        num_turns = utils.sample_positive_normal_integer(
            self.config.conversation.turn.mean,
            self.config.conversation.turn.stddev,
        )
        self.logger.debug("Creating conversation with %d turns", num_turns)

        for turn_idx in range(num_turns):
            # Only the first turn is flagged is_first (see _create_turn).
            turn = self._create_turn(is_first=(turn_idx == 0))
            conversation.turns.append(turn)
        conversations.append(conversation)
    return conversations

aiperf.services.dataset.dataset_manager

DatasetManager

Bases: BaseComponentService

The DatasetManager's primary responsibility is to manage the data generation or acquisition. For synthetic generation, it contains the code to generate the prompts or tokens. It will have an API for dataset acquisition of a dataset if available in a remote repository or database.

Source code in aiperf/services/dataset/dataset_manager.py
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
@ServiceFactory.register(ServiceType.DATASET_MANAGER)
class DatasetManager(BaseComponentService):
    """
    The DatasetManager primary responsibility is to manage the data generation or acquisition.
    For synthetic generation, it contains the code to generate the prompts or tokens.
    It will have an API for dataset acquisition of a dataset if available in a remote repository or database.

    Requests are served over a reply client bound to the dataset manager proxy
    backend. Three request types are handled: whole conversations, individual
    turns, and dataset timing data. All handlers require the dataset to be
    configured first (see ``_configure_dataset`` / ``dataset_configured``).
    """

    def __init__(
        self,
        service_config: ServiceConfig,
        user_config: UserConfig | None = None,
        service_id: str | None = None,
    ) -> None:
        super().__init__(
            service_config=service_config,
            user_config=user_config,
            service_id=service_id,
        )
        self.debug("Dataset manager __init__")
        self.user_config = user_config
        # NOTE(review): never reassigned afterwards -- _configure_dataset keeps
        # its tokenizer local. Confirm whether this attribute is still needed.
        self.tokenizer: Tokenizer | None = None
        self.dataset: dict[str, Conversation] = {}  # session ID -> Conversation mapping
        # Reply client through which conversation/turn/timing requests arrive.
        self.router_reply_client: ReplyClientProtocol = self.comms.create_reply_client(
            CommunicationClientAddressType.DATASET_MANAGER_PROXY_BACKEND
        )
        # Set once the dataset has been built; request handlers wait on it.
        self.dataset_configured = asyncio.Event()

    @property
    def service_type(self) -> ServiceType:
        """The type of service."""
        return ServiceType.DATASET_MANAGER

    @on_init
    async def _initialize(self) -> None:
        """Initialize dataset manager-specific components.

        Registers one handler per supported request message type on the
        reply client.
        """
        self.debug(lambda: f"Initializing dataset manager {self.service_id}")

        self.router_reply_client.register_request_handler(
            service_id=self.service_id,
            message_type=MessageType.CONVERSATION_REQUEST,
            handler=self._handle_conversation_request,
        )
        self.router_reply_client.register_request_handler(
            service_id=self.service_id,
            message_type=MessageType.DATASET_TIMING_REQUEST,
            handler=self._handle_dataset_timing_request,
        )
        self.router_reply_client.register_request_handler(
            service_id=self.service_id,
            message_type=MessageType.CONVERSATION_TURN_REQUEST,
            handler=self._handle_conversation_turn_request,
        )

        self.debug(lambda: f"Dataset manager {self.service_id} initialized")

    async def _configure_dataset(self) -> None:
        """Build the in-memory dataset, then signal readiness and publish a
        ``DatasetConfiguredNotification``.

        Raises a service error if no user config was supplied.
        """
        if self.user_config is None:
            raise self._service_error("User config is required for dataset manager")

        # The presence of an input file switches composition from synthetic
        # generation to loading custom data.
        if self.user_config.input.file:
            composer_type = ComposerType.CUSTOM
            self.debug(
                lambda: f"Detected input file '{self.user_config.input.file}'. Setting the composer type to {ComposerType.CUSTOM}."
            )
        else:
            composer_type = ComposerType.SYNTHETIC
            self.debug(
                lambda: f"No input file detected. Setting the composer type to {ComposerType.SYNTHETIC}."
            )

        # Fall back to the first model's name when no tokenizer is configured.
        tokenizer_name = self.user_config.tokenizer.name
        if tokenizer_name is None:
            # TODO: What do we do if there are multiple models?
            # How will we know which tokenizer to use?
            tokenizer_name = self.user_config.model_names[0]

        tokenizer = Tokenizer.from_pretrained(
            tokenizer_name,
            trust_remote_code=self.user_config.tokenizer.trust_remote_code,
            revision=self.user_config.tokenizer.revision,
        )
        composer = ComposerFactory.create_instance(
            composer_type,
            config=self.user_config.input,
            tokenizer=tokenizer,
        )
        conversations = composer.create_dataset()
        self.dataset = {conv.session_id: conv for conv in conversations}

        # Unblock any handlers waiting in _handle_conversation_request, then
        # broadcast that the dataset is ready.
        self.dataset_configured.set()
        await self.pub_client.publish(
            DatasetConfiguredNotification(
                service_id=self.service_id,
            ),
        )

    @on_configure
    async def _configure(self, message: Message) -> None:
        """Configure the dataset manager."""
        # TODO: This is a temporary hack with the changes to user config loading
        # Clear the event first so handlers block while the dataset rebuilds.
        self.dataset_configured.clear()
        await self._configure_dataset()

    async def _handle_conversation_request(
        self, message: ConversationRequestMessage
    ) -> ConversationResponseMessage:
        """Handle a conversation request.

        Returns a specific conversation when the request names one, otherwise
        a randomly chosen conversation from the dataset.
        """
        self.debug(lambda: f"Handling conversation request: {message}")

        # Wait for the dataset to be configured if it is not already.
        # asyncio.wait_for raises TimeoutError if configuration does not
        # finish within DATASET_CONFIGURATION_TIMEOUT.
        if not self.dataset_configured.is_set():
            self.debug(
                "Dataset not configured. Waiting for dataset to be configured..."
            )
            await asyncio.wait_for(
                self.dataset_configured.wait(), timeout=DATASET_CONFIGURATION_TIMEOUT
            )

        if not self.dataset:
            raise self._service_error(
                "Dataset is empty and must be configured before handling requests.",
            )

        if message.conversation_id is None:
            return self._return_any_conversation(
                request_id=message.request_id,
            )
        else:
            return self._return_conversation_by_id(
                request_id=message.request_id,
                conversation_id=message.conversation_id,
            )

    def _return_any_conversation(
        self, request_id: str | None
    ) -> ConversationResponseMessage:
        """Return any conversation from the dataset based on the user specified method."""

        # TODO: Implement the user specified method (random, round robin, etc.)
        conversation = random.choice(list(self.dataset.values()))
        self.debug(lambda: f"Sending random conversation response: {conversation}")
        return ConversationResponseMessage(
            service_id=self.service_id,
            request_id=request_id,
            conversation=conversation,
        )

    def _return_conversation_by_id(
        self, request_id: str | None, conversation_id: str
    ) -> ConversationResponseMessage:
        """Return a conversation if it exists, otherwise raise an error."""

        if conversation_id not in self.dataset:
            raise self._service_error(
                f"Conversation {conversation_id} not found in dataset.",
            )

        conversation = self.dataset[conversation_id]
        self.debug(lambda: f"Sending conversation response: {conversation}")
        return ConversationResponseMessage(
            service_id=self.service_id,
            request_id=request_id,
            conversation=conversation,
        )

    async def _handle_conversation_turn_request(
        self, message: ConversationTurnRequestMessage
    ) -> ConversationTurnResponseMessage:
        """Handle a turn request.

        Raises a service error for an unknown conversation ID or an
        out-of-range turn index.
        """
        self.debug(lambda: f"Handling turn request: {message}")

        if message.conversation_id not in self.dataset:
            raise self._service_error(
                f"Conversation {message.conversation_id} not found in dataset.",
            )

        conversation = self.dataset[message.conversation_id]
        if message.turn_index >= len(conversation.turns):
            raise self._service_error(
                f"Turn index {message.turn_index} is out of range for conversation {message.conversation_id}.",
            )

        turn = conversation.turns[message.turn_index]

        self.debug(lambda: f"Sending turn response: {turn}")
        return ConversationTurnResponseMessage(
            service_id=self.service_id,
            request_id=message.request_id,
            turn=turn,
        )

    async def _handle_dataset_timing_request(
        self, message: DatasetTimingRequest
    ) -> DatasetTimingResponse:
        """Handle a dataset timing request.

        Flattens the dataset into (timestamp, conversation_id) pairs, one
        per turn. NOTE(review): unlike the conversation handler, this one
        does not wait on ``dataset_configured`` -- confirm callers only send
        timing requests after configuration.
        """
        self.debug(lambda: f"Handling dataset timing request: {message}")
        if not self.dataset:
            raise self._service_error(
                "Dataset is empty and must be configured before handling timing requests.",
            )

        timing_dataset = []
        for conversation_id, conversation in self.dataset.items():
            for turn in conversation.turns:
                timing_dataset.append((turn.timestamp, conversation_id))

        return DatasetTimingResponse(
            service_id=self.service_id,
            request_id=message.request_id,
            timing_data=timing_dataset,
        )

service_type property

The type of service.

main()

Main entry point for the dataset manager.

Source code in aiperf/services/dataset/dataset_manager.py
252
253
254
255
256
257
def main() -> None:
    """Bootstrap and launch the DatasetManager service."""
    # Imported lazily so merely importing this module stays side-effect free.
    from aiperf.common.bootstrap import bootstrap_and_run_service

    bootstrap_and_run_service(DatasetManager)

aiperf.services.dataset.generator.audio

AudioGenerator

Bases: BaseGenerator

A class for generating synthetic audio data.

This class provides methods to create audio samples with specified characteristics such as format (WAV, MP3), length, sampling rate, bit depth, and number of channels. It supports validation of audio parameters to ensure compatibility with chosen formats.

Source code in aiperf/services/dataset/generator/audio.py
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
class AudioGenerator(BaseGenerator):
    """
    A class for generating synthetic audio data.

    This class provides methods to create audio samples with specified
    characteristics such as format (WAV, MP3), length, sampling rate,
    bit depth, and number of channels. It supports validation of audio
    parameters to ensure compatibility with chosen formats.

    The generated audio is Gaussian noise; only its container format and
    signal parameters are configurable.
    """

    def __init__(self, config: AudioConfig):
        super().__init__()
        # Audio generation parameters (format, length, rates, depths, channels).
        self.config = config

    def _validate_sampling_rate(
        self, sampling_rate_hz: int, audio_format: AudioFormat
    ) -> None:
        """
        Validate sampling rate for the given output format.

        Args:
            sampling_rate_hz: Sampling rate in Hz
            audio_format: Audio format

        Raises:
            ConfigurationError: If sampling rate is not supported for the given format
        """
        # Only MP3 restricts the sample rate; other formats accept any rate here.
        if (
            audio_format == AudioFormat.MP3
            and sampling_rate_hz not in MP3_SUPPORTED_SAMPLE_RATES
        ):
            supported_rates = sorted(MP3_SUPPORTED_SAMPLE_RATES)
            raise ConfigurationError(
                f"MP3 format only supports the following sample rates (in Hz): {supported_rates}. "
                f"Got {sampling_rate_hz} Hz. Please choose a supported rate from the list."
            )

    def _validate_bit_depth(self, bit_depth: int) -> None:
        """
        Validate bit depth is supported.

        Args:
            bit_depth: Bit depth in bits

        Raises:
            ConfigurationError: If bit depth is not supported
        """
        if bit_depth not in SUPPORTED_BIT_DEPTHS:
            supported_depths = sorted(SUPPORTED_BIT_DEPTHS.keys())
            raise ConfigurationError(
                f"Unsupported bit depth: {bit_depth}. "
                f"Supported bit depths are: {supported_depths}"
            )

    def generate(self, *args, **kwargs) -> str:
        """Generate audio data with specified parameters.

        Returns:
            Data URI containing base64-encoded audio data with format specification

        Raises:
            ConfigurationError: If any of the following conditions are met:
                - audio length is less than 0.01 seconds
                - channels is not 1 (mono) or 2 (stereo)
                - sampling rate is not supported for MP3 format
                - bit depth is not supported (must be 8, 16, 24, or 32)
                - audio format is not supported (must be 'wav' or 'mp3')
        """
        if self.config.num_channels not in (1, 2):
            raise ConfigurationError(
                "Only mono (1) and stereo (2) channels are supported"
            )

        # NOTE(review): a mean of exactly 0.01 passes this check even though
        # the message says "greater than" -- confirm the intended boundary.
        if self.config.length.mean < 0.01:
            raise ConfigurationError("Audio length must be greater than 0.01 seconds")

        # Sample audio length (in seconds) using rejection sampling
        audio_length = utils.sample_normal(
            self.config.length.mean, self.config.length.stddev, lower=0.01
        )

        # Randomly select sampling rate and bit depth
        sampling_rate_hz = int(
            np.random.choice(self.config.sample_rates) * 1000
        )  # Convert kHz to Hz
        # np.random.choice yields a NumPy scalar; dict lookups below still work.
        bit_depth = np.random.choice(self.config.depths)

        # Validate sampling rate and bit depth
        self._validate_sampling_rate(sampling_rate_hz, self.config.format)
        self._validate_bit_depth(bit_depth)

        # Generate synthetic audio data (gaussian noise)
        num_samples = int(audio_length * sampling_rate_hz)
        audio_data = np.random.normal(
            0,
            0.3,
            (
                # Stereo uses a (samples, channels) array; mono stays 1-D.
                (num_samples, self.config.num_channels)
                if self.config.num_channels > 1
                else num_samples
            ),
        )

        # Ensure the signal is within [-1, 1] range
        audio_data = np.clip(audio_data, -1, 1)

        # Scale to the appropriate bit depth range
        max_val = 2 ** (bit_depth - 1) - 1
        numpy_type, _ = SUPPORTED_BIT_DEPTHS[bit_depth]
        audio_data = (audio_data * max_val).astype(numpy_type)

        # Write audio using soundfile
        output_buffer = io.BytesIO()

        # Select appropriate subtype based on format
        if self.config.format == AudioFormat.MP3:
            subtype = "MPEG_LAYER_III"
        elif self.config.format == AudioFormat.WAV:
            # WAV subtype depends on the chosen bit depth.
            _, subtype = SUPPORTED_BIT_DEPTHS[bit_depth]
        else:
            raise ConfigurationError(
                f"Unsupported audio format: {self.config.format}. "
                f"Supported formats are: {AudioFormat.WAV.name}, {AudioFormat.MP3.name}"
            )

        # NOTE(review): the enum is passed to sf.write directly and .lower()
        # is called on it below, which suggests AudioFormat is str-valued --
        # confirm against its definition.
        sf.write(
            output_buffer,
            audio_data,
            sampling_rate_hz,
            format=self.config.format,
            subtype=subtype,
        )
        audio_bytes = output_buffer.getvalue()

        # Encode to base64 with data URI scheme: "{format},{data}"
        base64_data = base64.b64encode(audio_bytes).decode("utf-8")
        return f"{self.config.format.lower()},{base64_data}"

generate(*args, **kwargs)

Generate audio data with specified parameters.

Returns:

Type Description
str

Data URI containing base64-encoded audio data with format specification

Raises:

Type Description
ConfigurationError

If any of the following conditions are met: - audio length is less than 0.01 seconds - channels is not 1 (mono) or 2 (stereo) - sampling rate is not supported for MP3 format - bit depth is not supported (must be 8, 16, 24, or 32) - audio format is not supported (must be 'wav' or 'mp3')

Source code in aiperf/services/dataset/generator/audio.py
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
def generate(self, *args, **kwargs) -> str:
    """Generate audio data with specified parameters.

    Returns:
        Data URI containing base64-encoded audio data with format specification

    Raises:
        ConfigurationError: If any of the following conditions are met:
            - audio length is less than 0.01 seconds
            - channels is not 1 (mono) or 2 (stereo)
            - sampling rate is not supported for MP3 format
            - bit depth is not supported (must be 8, 16, 24, or 32)
            - audio format is not supported (must be 'wav' or 'mp3')
    """
    if self.config.num_channels not in (1, 2):
        raise ConfigurationError(
            "Only mono (1) and stereo (2) channels are supported"
        )

    # NOTE(review): a mean of exactly 0.01 passes this check even though the
    # message says "greater than" -- confirm the intended boundary.
    if self.config.length.mean < 0.01:
        raise ConfigurationError("Audio length must be greater than 0.01 seconds")

    # Sample audio length (in seconds) using rejection sampling
    audio_length = utils.sample_normal(
        self.config.length.mean, self.config.length.stddev, lower=0.01
    )

    # Randomly select sampling rate and bit depth
    sampling_rate_hz = int(
        np.random.choice(self.config.sample_rates) * 1000
    )  # Convert kHz to Hz
    bit_depth = np.random.choice(self.config.depths)

    # Validate sampling rate and bit depth
    self._validate_sampling_rate(sampling_rate_hz, self.config.format)
    self._validate_bit_depth(bit_depth)

    # Generate synthetic audio data (gaussian noise)
    num_samples = int(audio_length * sampling_rate_hz)
    audio_data = np.random.normal(
        0,
        0.3,
        (
            # Stereo uses a (samples, channels) array; mono stays 1-D.
            (num_samples, self.config.num_channels)
            if self.config.num_channels > 1
            else num_samples
        ),
    )

    # Ensure the signal is within [-1, 1] range
    audio_data = np.clip(audio_data, -1, 1)

    # Scale to the appropriate bit depth range
    max_val = 2 ** (bit_depth - 1) - 1
    numpy_type, _ = SUPPORTED_BIT_DEPTHS[bit_depth]
    audio_data = (audio_data * max_val).astype(numpy_type)

    # Write audio using soundfile
    output_buffer = io.BytesIO()

    # Select appropriate subtype based on format
    if self.config.format == AudioFormat.MP3:
        subtype = "MPEG_LAYER_III"
    elif self.config.format == AudioFormat.WAV:
        # WAV subtype depends on the chosen bit depth.
        _, subtype = SUPPORTED_BIT_DEPTHS[bit_depth]
    else:
        raise ConfigurationError(
            f"Unsupported audio format: {self.config.format}. "
            f"Supported formats are: {AudioFormat.WAV.name}, {AudioFormat.MP3.name}"
        )

    sf.write(
        output_buffer,
        audio_data,
        sampling_rate_hz,
        format=self.config.format,
        subtype=subtype,
    )
    audio_bytes = output_buffer.getvalue()

    # Encode to base64 with data URI scheme: "{format},{data}"
    base64_data = base64.b64encode(audio_bytes).decode("utf-8")
    return f"{self.config.format.lower()},{base64_data}"

aiperf.services.dataset.generator.base

BaseGenerator

Bases: ABC

Abstract base class for all data generators.

Provides a consistent interface for generating synthetic data while allowing each generator type to use its own specific configuration and runtime parameters.

Source code in aiperf/services/dataset/generator/base.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class BaseGenerator(ABC):
    """Common contract shared by every synthetic data generator.

    Concrete subclasses carry their own configuration objects and runtime
    parameters, but all expose the same ``generate`` entry point that yields
    data rendered as a string.
    """

    def __init__(self):
        # Per-subclass logger, named after the concrete generator type.
        self.logger = logging.getLogger(type(self).__name__)

    @abstractmethod
    def generate(self, *args, **kwargs) -> str:
        """Produce one piece of synthetic data.

        Args:
            *args: Positional parameters defined by each subclass
            **kwargs: Keyword parameters defined by each subclass

        Returns:
            The generated data as a string (plain text, base64 encoded media, etc.)
        """

generate(*args, **kwargs) abstractmethod

Generate synthetic data.

Parameters:

Name Type Description Default
*args

Variable length argument list (subclass-specific)

()
**kwargs

Arbitrary keyword arguments (subclass-specific)

{}

Returns:

Type Description
str

Generated data as a string (could be text, base64 encoded media, etc.)

Source code in aiperf/services/dataset/generator/base.py
18
19
20
21
22
23
24
25
26
27
28
29
@abstractmethod
def generate(self, *args, **kwargs) -> str:
    """Generate synthetic data.

    Each concrete generator defines its own positional/keyword parameters;
    this base declaration only fixes the string return contract.

    Args:
        *args: Variable length argument list (subclass-specific)
        **kwargs: Arbitrary keyword arguments (subclass-specific)

    Returns:
        Generated data as a string (could be text, base64 encoded media, etc.)
    """
    pass

aiperf.services.dataset.generator.image

ImageGenerator

Bases: BaseGenerator

A class that generates images from source images.

This class provides methods to create synthetic images by resizing source images (located in the 'assets/source_images' directory) to specified dimensions and converting them to a chosen image format (e.g., PNG, JPEG). The dimensions can be randomized based on mean and standard deviation values.

Source code in aiperf/services/dataset/generator/image.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
class ImageGenerator(BaseGenerator):
    """A class that generates images from source images.

    This class provides methods to create synthetic images by resizing
    source images (located in the 'assets/source_images' directory)
    to specified dimensions and converting them to a chosen image format (e.g., PNG, JPEG).
    The dimensions can be randomized based on mean and standard deviation values.
    """

    def __init__(self, config: ImageConfig):
        super().__init__()
        # Image generation parameters (format, width/height distributions).
        self.config = config

    def generate(self, *args, **kwargs) -> str:
        """Generate an image with the configured parameters.

        Returns:
            A base64 encoded string of the generated image.
        """
        image_format = self.config.format
        if image_format == ImageFormat.RANDOM:
            # Resolve RANDOM to a concrete format; RANDOM itself is never emitted.
            image_format = random.choice(
                [f for f in ImageFormat if f != ImageFormat.RANDOM]
            )

        # Dimensions are drawn from per-axis normal distributions.
        width = utils.sample_positive_normal_integer(
            self.config.width.mean, self.config.width.stddev
        )
        height = utils.sample_positive_normal_integer(
            self.config.height.mean, self.config.height.stddev
        )

        self.logger.debug(
            "Generating image with width=%d, height=%d",
            width,
            height,
        )

        image = self._sample_source_image()
        image = image.resize(size=(width, height))
        base64_image = utils.encode_image(image, image_format)
        return f"data:image/{image_format.name.lower()};base64,{base64_image}"

    def _sample_source_image(self):
        """Sample one image among the source images.

        Returns:
            A PIL Image object randomly selected from the source images,
            with its pixel data fully loaded into memory.

        Raises:
            ValueError: If no source images are found in the assets directory.
        """
        filepath = Path(__file__).parent.resolve() / "assets" / "source_images" / "*"
        filenames = glob.glob(str(filepath))
        if not filenames:
            raise ValueError(f"No source images found in '{filepath}'")
        image = Image.open(random.choice(filenames))
        # Image.open is lazy and keeps the file handle open until the pixel
        # data is read; force the read now so Pillow releases the handle and
        # we do not leak one open file per generated image.
        image.load()
        return image

generate(*args, **kwargs)

Generate an image with the configured parameters.

Returns:

Type Description
str

A base64 encoded string of the generated image.

Source code in aiperf/services/dataset/generator/image.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
def generate(self, *args, **kwargs) -> str:
    """Generate an image with the configured parameters.

    Returns:
        A base64 encoded string of the generated image.
    """
    image_format = self.config.format
    if image_format == ImageFormat.RANDOM:
        # Resolve RANDOM to a concrete format; RANDOM itself is never emitted.
        image_format = random.choice(
            [f for f in ImageFormat if f != ImageFormat.RANDOM]
        )

    # Dimensions are drawn from per-axis normal distributions.
    width = utils.sample_positive_normal_integer(
        self.config.width.mean, self.config.width.stddev
    )
    height = utils.sample_positive_normal_integer(
        self.config.height.mean, self.config.height.stddev
    )

    self.logger.debug(
        "Generating image with width=%d, height=%d",
        width,
        height,
    )

    image = self._sample_source_image()
    image = image.resize(size=(width, height))
    base64_image = utils.encode_image(image, image_format)
    return f"data:image/{image_format.name.lower()};base64,{base64_image}"

aiperf.services.dataset.generator.prompt

PromptGenerator

Bases: BaseGenerator

A class for generating synthetic prompts from a text corpus.

This class loads a text corpus (e.g., Shakespearean text), tokenizes it, and uses the tokenized corpus to generate synthetic prompts of specified lengths. It supports generating prompts with a target number of tokens (with optional randomization around a mean and standard deviation) and can reuse previously generated token blocks to optimize generation for certain use cases. It also allows for the creation of a pool of prefix prompts that can be randomly selected.

Source code in aiperf/services/dataset/generator/prompt.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
class PromptGenerator(BaseGenerator):
    """A class for generating synthetic prompts from a text corpus.

    This class loads a text corpus (e.g., Shakespearean text), tokenizes it,
    and uses the tokenized corpus to generate synthetic prompts of specified
    lengths. It supports generating prompts with a target number of tokens
    (with optional randomization around a mean and standard deviation) and
    can reuse previously generated token blocks to optimize generation for
    certain use cases. It also allows for the creation of a pool of prefix
    prompts that can be randomly selected.
    """

    def __init__(self, config: PromptConfig, tokenizer: Tokenizer):
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer
        # Flattened token IDs of the full corpus; filled by _initialize_corpus.
        self._tokenized_corpus = None
        self._corpus_size = 0
        self._prefix_prompts: list[str] = []

        # Cached prompts: block ID -> list of tokens
        self._cache: dict[int, list[int]] = {}

        # TODO: move this under initialize() method
        # Initialize corpus if not already done
        # NOTE(review): this guard is always true here since the attribute was
        # just set to None above; it only matters once the TODO is addressed.
        if self._tokenized_corpus is None:
            self._initialize_corpus()

        # Initialize prefix prompts pool if the pool size > 0
        if self.config.prefix_prompt.pool_size > 0:
            self._create_prefix_prompt_pool()

    def _initialize_corpus(self) -> None:
        """Load and tokenize the corpus once, storing it for reuse."""
        corpus_path = Path(__file__).parent / DEFAULT_CORPUS_FILE

        with open(corpus_path) as f:
            lines = f.readlines()

        def tokenize_chunk(chunk):
            # Collapse the chunk into one whitespace-normalized string so
            # line breaks do not influence tokenization.
            cleaned_text = " ".join(line.strip() for line in chunk if line.strip())
            tokens = self.tokenizer.encode(cleaned_text)
            return tokens

        num_threads = os.cpu_count()
        if num_threads is None:
            num_threads = 4

        # Ensure chunk_size is at least 1 to avoid division by zero in range()
        chunk_size = max(1, len(lines) // num_threads)
        chunks = [lines[i : i + chunk_size] for i in range(0, len(lines), chunk_size)]

        # Tokenize chunks in parallel; tokenizer implementations typically
        # release the GIL during native encoding.
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            tokenized_chunks = list(executor.map(tokenize_chunk, chunks))

        self._tokenized_corpus = [
            token for chunk in tokenized_chunks for token in chunk
        ]
        self._corpus_size = len(self._tokenized_corpus)
        self.logger.debug("Initialized corpus with %d tokens", self._corpus_size)

    def _create_prefix_prompt_pool(self) -> None:
        """Generate a pool of prefix prompts to sample from.

        Raises:
            NotInitializedError: If the tokenized corpus is not initialized.
        """
        if self._tokenized_corpus is None:
            raise NotInitializedError("Tokenized corpus is not initialized.")

        self._prefix_prompts = [
            self._generate_prompt(self.config.prefix_prompt.length)
            for _ in range(self.config.prefix_prompt.pool_size)
        ]
        self.logger.debug(
            "Initialized prefix prompts pool with %d prompts",
            len(self._prefix_prompts),
        )

    def generate(
        self,
        mean: int | None = None,
        stddev: int | None = None,
        hash_ids: list[int] | None = None,
    ) -> str:
        """Generate a synthetic prompt with the configuration parameters.

        Args:
            mean: The mean of the normal distribution.
            stddev: The standard deviation of the normal distribution.
            hash_ids: A list of hash indices used for token reuse.

        Returns:
            A synthetic prompt as a string.
        """
        # NOTE(review): when hash_ids is provided, `mean` is forwarded as the
        # exact token count and must not be None -- confirm callers guarantee
        # this.
        if hash_ids:
            return self._generate_cached_prompt(
                mean, hash_ids, self.config.input_tokens.block_size
            )

        num_tokens = utils.sample_positive_normal_integer(mean, stddev)
        return self._generate_prompt(num_tokens)

    def _generate_prompt(self, num_tokens: int) -> str:
        """Generate a prompt containing exactly `num_tokens` number of tokens.

        Args:
            num_tokens: Number of tokens required in the prompt.

        Returns:
            A synthetic prompt as a string.
        """
        return self.tokenizer.decode(self._sample_tokens(num_tokens))

    def _generate_cached_prompt(
        self,
        num_tokens: int,
        hash_ids: list[int],
        block_size: int,
    ) -> str:
        """
        Generate a prompt containing exactly `num_tokens` by reusing previously generated prompts
        stored in `_cache`. Each hash index in `hash_ids` corresponds to a block of
        `block_size` tokens. If a hash index is found in `_cache`, its stored prompt is reused.
        Otherwise, a new prompt is generated using `_generate_prompt()` and stored in `_cache`.

        Args:
            num_tokens: The number of tokens required in the prompt.
            hash_ids: A list of hash IDs to use for token reuse.
            block_size: The number of tokens allocated per hash block.

        Returns:
            str: A synthetic prompt as a string.

        Raises:
            ConfigurationError: If the input parameters are not compatible.
        """
        final_prompt: list[int] = []
        current_block_size = block_size

        # Sanity check the final block size
        final_block_size = num_tokens - ((len(hash_ids) - 1) * block_size)
        if final_block_size <= 0 or block_size < final_block_size:
            raise ConfigurationError(
                f"Input length: {num_tokens}, Hash IDs: {hash_ids}, Block size: {block_size} "
                f"are not compatible. The final hash block size: {final_block_size} must be "
                f"greater than 0 and less than or equal to {block_size}."
            )

        for index, hash_id in enumerate(hash_ids):
            # For the last hash ID, use the remaining tokens as the block size
            if index == len(hash_ids) - 1:
                current_block_size = final_block_size

            if hash_id not in self._cache:
                # To ensure that the prompt doesn't merge chunks, we pop the last token
                # and insert the bos token at the beginning. Length is maintained and
                # the prompt generates the expected number of tokens.
                prompt_tokens: list[int] = self._sample_tokens(current_block_size)
                prompt_tokens.pop(0)
                prompt_tokens.insert(0, self.tokenizer.bos_token_id)
                self._cache[hash_id] = prompt_tokens  # store to cache

            final_prompt.extend(self._cache[hash_id])

        return self.tokenizer.decode(final_prompt, skip_special_tokens=False)

    def _sample_tokens(self, num_tokens: int) -> list[int]:
        """Generate a list of token IDs containing exactly `num_tokens` number of tokens
        using the preloaded tokenized corpus.

        Args:
            num_tokens: Number of tokens required in the prompt.

        Returns:
            A list of token IDs.

        Raises:
            NotInitializedError: If the tokenized corpus is not initialized
        """
        if not self._tokenized_corpus:
            raise NotInitializedError("Tokenized corpus is not initialized.")
        if num_tokens > self._corpus_size:
            # NOTE(review): this uses the module-level `logger`, while the rest
            # of the class logs through `self.logger` -- confirm which is
            # intended. The message is also not exact: the wrap-around below
            # can still return num_tokens tokens for requests up to twice the
            # corpus size.
            logger.warning(
                f"Requested prompt length {num_tokens} is longer than the corpus. "
                f"Returning a prompt of length {self._corpus_size}."
            )

        start_idx = random.randrange(self._corpus_size)

        # Wrap around to the start of the corpus when the slice runs past the
        # end, so any start index can serve the requested length.
        end_idx = start_idx + num_tokens
        prompt_tokens = self._tokenized_corpus[start_idx:end_idx]
        if end_idx > self._corpus_size:
            prompt_tokens += self._tokenized_corpus[: end_idx - self._corpus_size]

        self.logger.debug("Sampled %d tokens from corpus", len(prompt_tokens))
        return prompt_tokens

    def get_random_prefix_prompt(self) -> str:
        """
        Fetch a random prefix prompt from the pool.

        Returns:
            A random prefix prompt.

        Raises:
            InvalidStateError: If the prefix prompts pool is empty.
        """
        if not self._prefix_prompts:
            raise InvalidStateError(
                "Attempted to sample a prefix prompt but the prefix prompts pool is empty. "
                "Please ensure that the prefix prompts pool is initialized."
            )
        return random.choice(self._prefix_prompts)

generate(mean=None, stddev=None, hash_ids=None)

Generate a synthetic prompt with the configuration parameters.

Parameters:

Name Type Description Default
mean int | None

The mean of the normal distribution.

None
stddev int | None

The standard deviation of the normal distribution.

None
hash_ids list[int] | None

A list of hash indices used for token reuse.

None

Returns:

Type Description
str

A synthetic prompt as a string.

Source code in aiperf/services/dataset/generator/prompt.py
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
def generate(
    self,
    mean: int | None = None,
    stddev: int | None = None,
    hash_ids: list[int] | None = None,
) -> str:
    """Produce a synthetic prompt using the configuration parameters.

    When hash indices are supplied, a cached prompt is generated so tokens
    can be reused across requests; otherwise the prompt length is drawn
    from a normal distribution.

    Args:
        mean: The mean of the normal distribution.
        stddev: The standard deviation of the normal distribution.
        hash_ids: A list of hash indices used for token reuse.

    Returns:
        A synthetic prompt as a string.
    """
    if not hash_ids:
        num_tokens = utils.sample_positive_normal_integer(mean, stddev)
        return self._generate_prompt(num_tokens)
    block_size = self.config.input_tokens.block_size
    return self._generate_cached_prompt(mean, hash_ids, block_size)

get_random_prefix_prompt()

Fetch a random prefix prompt from the pool.

Returns:

Type Description
str

A random prefix prompt.

Raises:

Type Description
InvalidStateError

If the prefix prompts pool is empty.

Source code in aiperf/services/dataset/generator/prompt.py
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
def get_random_prefix_prompt(self) -> str:
    """Pick one prefix prompt uniformly at random from the pool.

    Returns:
        A prefix prompt sampled from the configured pool.

    Raises:
        InvalidStateError: If the prefix prompts pool is empty.
    """
    if not self._prefix_prompts:
        msg = (
            "Attempted to sample a prefix prompt but the prefix prompts pool is empty. "
            "Please ensure that the prefix prompts pool is initialized."
        )
        raise InvalidStateError(msg)
    return random.choice(self._prefix_prompts)

aiperf.services.dataset.loader.models

CustomData = Annotated[SingleTurn | MooncakeTrace | MultiTurn, Field(discriminator='type')] module-attribute

A union type of all custom data types.

MooncakeTrace

Bases: AIPerfBaseModel

Defines the schema for Mooncake trace data.

See https://github.com/kvcache-ai/Mooncake for more details.

Example:

{"timestamp": 1000, "input_length": 10, "output_length": 4, "hash_ids": [123, 456]}
Source code in aiperf/services/dataset/loader/models.py
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
class MooncakeTrace(AIPerfBaseModel):
    """Defines the schema for Mooncake trace data.

    See https://github.com/kvcache-ai/Mooncake for more details.

    Example:
    ```json
    {"timestamp": 1000, "input_length": 10, "output_length": 4, "hash_ids": [123, 456]}
    ```
    """

    # Discriminator used by the CustomData union to select this schema.
    type: Literal[CustomDatasetType.MOONCAKE_TRACE] = CustomDatasetType.MOONCAKE_TRACE

    input_length: int = Field(..., description="The input sequence length of a request")
    output_length: int = Field(
        ..., description="The output sequence length of a request"
    )
    # Hash ids feed the prompt generator's cached-prompt path for token reuse.
    hash_ids: list[int] = Field(..., description="The hash ids of a request")
    timestamp: int = Field(..., description="The timestamp of a request")

MultiTurn

Bases: AIPerfBaseModel

Defines the schema for multi-turn conversations.

The multi-turn custom dataset - supports multi-modal data (e.g. text, image, audio) - supports multi-turn features (e.g. delay, sessions, etc.) - supports client-side batching for each data (e.g. batch size > 1)

Source code in aiperf/services/dataset/loader/models.py
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
class MultiTurn(AIPerfBaseModel):
    """Defines the schema for multi-turn conversations.

    The multi-turn custom dataset
      - supports multi-modal data (e.g. text, image, audio)
      - supports multi-turn features (e.g. delay, sessions, etc.)
      - supports client-side batching for each data (e.g. batch size > 1)
    """

    # Discriminator used by the CustomData union to select this schema.
    type: Literal[CustomDatasetType.MULTI_TURN] = CustomDatasetType.MULTI_TURN

    # Optional: entries without a session_id get a generated UUID at load time.
    session_id: str | None = Field(
        None, description="Unique identifier for the conversation session"
    )
    turns: list[SingleTurn] = Field(
        ..., description="List of turns in the conversation"
    )

    @model_validator(mode="after")
    def validate_turns_not_empty(self) -> "MultiTurn":
        """Ensure at least one turn is provided"""
        if not self.turns:
            raise ValueError("At least one turn must be provided")
        return self

validate_turns_not_empty()

Ensure at least one turn is provided

Source code in aiperf/services/dataset/loader/models.py
91
92
93
94
95
96
@model_validator(mode="after")
def validate_turns_not_empty(self) -> "MultiTurn":
    """Reject a conversation that contains no turns."""
    if self.turns:
        return self
    raise ValueError("At least one turn must be provided")

RandomPool

Bases: AIPerfBaseModel

Defines the schema for random pool data entry.

The random pool custom dataset - supports multi-modal data (e.g. text, image, audio) - supports client-side batching for each data (e.g. batch size > 1) - supports named fields for each modality (e.g. text_field_a, text_field_b, etc.) - DOES NOT support multi-turn or its features (e.g. delay, sessions, etc.)

Source code in aiperf/services/dataset/loader/models.py
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
class RandomPool(AIPerfBaseModel):
    """Defines the schema for random pool data entry.

    The random pool custom dataset
      - supports multi-modal data (e.g. text, image, audio)
      - supports client-side batching for each data (e.g. batch size > 1)
      - supports named fields for each modality (e.g. text_field_a, text_field_b, etc.)
      - DOES NOT support multi-turn or its features (e.g. delay, sessions, etc.)
    """

    # Discriminator used by the CustomData union to select this schema.
    type: Literal[CustomDatasetType.RANDOM_POOL] = CustomDatasetType.RANDOM_POOL

    # Singular fields carry one item; the plural variants carry a client-side
    # batch. Each singular/plural pair is mutually exclusive (validated below).
    text: str | None = Field(None, description="Simple text string content")
    texts: list[str] | list[Text] | None = Field(
        None,
        description="List of text strings or Text objects format",
    )
    image: str | None = Field(None, description="Simple image string content")
    images: list[str] | list[Image] | None = Field(
        None,
        description="List of image strings or Image objects format",
    )
    audio: str | None = Field(None, description="Simple audio string content")
    audios: list[str] | list[Audio] | None = Field(
        None,
        description="List of audio strings or Audio objects format",
    )

    @model_validator(mode="after")
    def validate_mutually_exclusive_fields(self) -> "RandomPool":
        """Ensure mutually exclusive fields are not set together"""
        if self.text and self.texts:
            raise ValueError("text and texts cannot be set together")
        if self.image and self.images:
            raise ValueError("image and images cannot be set together")
        if self.audio and self.audios:
            raise ValueError("audio and audios cannot be set together")
        return self

    @model_validator(mode="after")
    def validate_at_least_one_modality(self) -> "RandomPool":
        """Ensure at least one modality is provided"""
        if not any(
            [self.text, self.texts, self.image, self.images, self.audio, self.audios]
        ):
            raise ValueError("At least one modality must be provided")
        return self

validate_at_least_one_modality()

Ensure at least one modality is provided

Source code in aiperf/services/dataset/loader/models.py
138
139
140
141
142
143
144
145
@model_validator(mode="after")
def validate_at_least_one_modality(self) -> "RandomPool":
    """Require that the entry carries content for at least one modality."""
    has_content = any(
        value
        for value in (
            self.text,
            self.texts,
            self.image,
            self.images,
            self.audio,
            self.audios,
        )
    )
    if not has_content:
        raise ValueError("At least one modality must be provided")
    return self

validate_mutually_exclusive_fields()

Ensure mutually exclusive fields are not set together

Source code in aiperf/services/dataset/loader/models.py
127
128
129
130
131
132
133
134
135
136
@model_validator(mode="after")
def validate_mutually_exclusive_fields(self) -> "RandomPool":
    """Disallow setting the singular and plural form of a modality together."""
    exclusive_pairs = (
        (self.text, self.texts, "text and texts cannot be set together"),
        (self.image, self.images, "image and images cannot be set together"),
        (self.audio, self.audios, "audio and audios cannot be set together"),
    )
    for single, batched, message in exclusive_pairs:
        if single and batched:
            raise ValueError(message)
    return self

SingleTurn

Bases: AIPerfBaseModel

Defines the schema for single-turn data.

User can use this format to quickly provide a custom single turn dataset. Each line in the file will be treated as a single turn conversation.

The single turn type - supports multi-modal (e.g. text, image, audio) - supports client-side batching for each data (e.g. batch_size > 1) - DOES NOT support multi-turn features (e.g. session_id)

Source code in aiperf/services/dataset/loader/models.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
class SingleTurn(AIPerfBaseModel):
    """Defines the schema for single-turn data.

    User can use this format to quickly provide a custom single turn dataset.
    Each line in the file will be treated as a single turn conversation.

    The single turn type
      - supports multi-modal (e.g. text, image, audio)
      - supports client-side batching for each data (e.g. batch_size > 1)
      - DOES NOT support multi-turn features (e.g. session_id)
    """

    # Discriminator used by the CustomData union to select this schema.
    type: Literal[CustomDatasetType.SINGLE_TURN] = CustomDatasetType.SINGLE_TURN

    # Singular fields carry one item; the plural variants carry a client-side
    # batch. Each singular/plural pair is mutually exclusive (validated below).
    text: str | None = Field(None, description="Simple text string content")
    texts: list[str] | list[Text] | None = Field(
        None,
        description="List of text strings or Text objects format",
    )
    image: str | None = Field(None, description="Simple image string content")
    images: list[str] | list[Image] | None = Field(
        None,
        description="List of image strings or Image objects format",
    )
    audio: str | None = Field(None, description="Simple audio string content")
    audios: list[str] | list[Audio] | None = Field(
        None,
        description="List of audio strings or Audio objects format",
    )
    # Scheduling controls: timestamp and delay are mutually exclusive.
    timestamp: int | None = Field(
        default=None, description="Timestamp of the turn in milliseconds."
    )
    delay: int | None = Field(
        default=None,
        description="Amount of milliseconds to wait before sending the turn.",
    )
    role: str | None = Field(default=None, description="Role of the turn.")

    @model_validator(mode="after")
    def validate_mutually_exclusive_fields(self) -> "SingleTurn":
        """Ensure mutually exclusive fields are not set together"""
        if self.text and self.texts:
            raise ValueError("text and texts cannot be set together")
        if self.image and self.images:
            raise ValueError("image and images cannot be set together")
        if self.audio and self.audios:
            raise ValueError("audio and audios cannot be set together")
        if self.timestamp and self.delay:
            raise ValueError("timestamp and delay cannot be set together")
        return self

    @model_validator(mode="after")
    def validate_at_least_one_modality(self) -> "SingleTurn":
        """Ensure at least one modality is provided"""
        if not any(
            [self.text, self.texts, self.image, self.images, self.audio, self.audios]
        ):
            raise ValueError("At least one modality must be provided")
        return self

validate_at_least_one_modality()

Ensure at least one modality is provided

Source code in aiperf/services/dataset/loader/models.py
63
64
65
66
67
68
69
70
@model_validator(mode="after")
def validate_at_least_one_modality(self) -> "SingleTurn":
    """Ensure the turn carries content for at least one modality."""
    candidates = [self.text, self.texts, self.image, self.images, self.audio, self.audios]
    if not any(candidates):
        raise ValueError("At least one modality must be provided")
    return self

validate_mutually_exclusive_fields()

Ensure mutually exclusive fields are not set together

Source code in aiperf/services/dataset/loader/models.py
50
51
52
53
54
55
56
57
58
59
60
61
@model_validator(mode="after")
def validate_mutually_exclusive_fields(self) -> "SingleTurn":
    """Disallow field combinations that conflict with each other."""
    conflicts = (
        (self.text, self.texts, "text and texts cannot be set together"),
        (self.image, self.images, "image and images cannot be set together"),
        (self.audio, self.audios, "audio and audios cannot be set together"),
        (self.timestamp, self.delay, "timestamp and delay cannot be set together"),
    )
    for first, second, message in conflicts:
        if first and second:
            raise ValueError(message)
    return self

aiperf.services.dataset.loader.mooncake_trace

MooncakeTraceDatasetLoader

A dataset loader that loads Mooncake trace data from a file.

Loads Mooncake trace data from a file and converts the data into a list of conversations for dataset manager.

Each line in the file represents a single trace entry and will be converted to a separate conversation with a unique session ID.

Example: Fixed schedule version (Each line is a distinct session. Multi-turn is NOT supported)

{"timestamp": 1000, "input_length": 300, "output_length": 40, "hash_ids": [123, 456]}
Source code in aiperf/services/dataset/loader/mooncake_trace.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
@CustomDatasetFactory.register(CustomDatasetType.MOONCAKE_TRACE)
class MooncakeTraceDatasetLoader:
    """A dataset loader that reads Mooncake trace data from a JSONL file.

    Every non-empty line in the file is parsed as one trace entry and is
    converted into its own conversation under a freshly generated session
    ID (multi-turn is NOT supported).

    See https://github.com/kvcache-ai/Mooncake for the trace format.

    Example:
    Fixed schedule version (Each line is a distinct session. Multi-turn is NOT supported)
    ```json
    {"timestamp": 1000, "input_length": 300, "output_length": 40, "hash_ids": [123, 456]}
    ```
    """

    def __init__(self, filename: str, prompt_generator: PromptGenerator):
        self.filename = filename
        self.prompt_generator = prompt_generator

    def load_dataset(self) -> dict[str, list[MooncakeTrace]]:
        """Read and validate every trace line in the file.

        Returns:
            A dictionary of session_id and list of Mooncake trace data.
        """
        traces_by_session: dict[str, list[MooncakeTrace]] = defaultdict(list)

        with open(self.filename) as file:
            for raw_line in file:
                stripped = raw_line.strip()
                if not stripped:
                    continue  # ignore blank lines

                # Every line gets its own unique session.
                traces_by_session[str(uuid.uuid4())].append(
                    MooncakeTrace.model_validate_json(stripped)
                )

        return traces_by_session

    def convert_to_conversations(
        self, data: dict[str, list[MooncakeTrace]]
    ) -> list[Conversation]:
        """Convert all the Mooncake trace data to conversation objects.

        Args:
            data: A dictionary of session_id and list of Mooncake trace data.

        Returns:
            A list of conversations.
        """
        conversations: list[Conversation] = []
        for session_id, traces in data.items():
            conversation = Conversation(session_id=session_id)
            for trace in traces:
                # stddev=0 pins the generated prompt length to input_length.
                prompt_text = self.prompt_generator.generate(
                    mean=trace.input_length,
                    stddev=0,
                    hash_ids=trace.hash_ids,
                )
                conversation.turns.append(
                    Turn(
                        timestamp=trace.timestamp,
                        texts=[Text(name="text", contents=[prompt_text])],
                    )
                )
            conversations.append(conversation)
        return conversations

convert_to_conversations(data)

Convert all the Mooncake trace data to conversation objects.

Parameters:

Name Type Description Default
data dict[str, list[MooncakeTrace]]

A dictionary of session_id and list of Mooncake trace data.

required

Returns:

Type Description
list[Conversation]

A list of conversations.

Source code in aiperf/services/dataset/loader/mooncake_trace.py
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
def convert_to_conversations(
    self, data: dict[str, list[MooncakeTrace]]
) -> list[Conversation]:
    """Build one conversation per session from Mooncake trace entries.

    Args:
        data: A dictionary of session_id and list of Mooncake trace data.

    Returns:
        A list of conversations.
    """
    result: list[Conversation] = []
    for session_id, trace_list in data.items():
        convo = Conversation(session_id=session_id)
        for entry in trace_list:
            # stddev=0 pins the generated prompt length to input_length.
            generated = self.prompt_generator.generate(
                mean=entry.input_length,
                stddev=0,
                hash_ids=entry.hash_ids,
            )
            convo.turns.append(
                Turn(
                    timestamp=entry.timestamp,
                    texts=[Text(name="text", contents=[generated])],
                )
            )
        result.append(convo)
    return result

load_dataset()

Load Mooncake trace data from a file.

Returns:

Type Description
dict[str, list[MooncakeTrace]]

A dictionary of session_id and list of Mooncake trace data.

Source code in aiperf/services/dataset/loader/mooncake_trace.py
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def load_dataset(self) -> dict[str, list[MooncakeTrace]]:
    """Parse the trace file, one JSON record per non-empty line.

    Returns:
        A dictionary of session_id and list of Mooncake trace data.
    """
    sessions: dict[str, list[MooncakeTrace]] = defaultdict(list)

    with open(self.filename) as handle:
        for raw in handle:
            record = raw.strip()
            if not record:
                continue  # skip blank lines

            # Every line gets its own unique session.
            sessions[str(uuid.uuid4())].append(
                MooncakeTrace.model_validate_json(record)
            )

    return sessions

aiperf.services.dataset.loader.multi_turn

MultiTurnDatasetLoader

A dataset loader that loads multi-turn data from a file.

The multi-turn type - supports multi-modal data (e.g. text, image, audio) - supports multi-turn features (e.g. delay, sessions, etc.) - supports client-side batching for each data (e.g. batch_size > 1)

NOTE: If the user specifies multiple multi-turn entries with same session ID, the loader will group them together. If the timestamps are specified, they will be sorted in ascending order later in the timing manager.

Examples: 1. Simple version

{
    "session_id": "session_123",
    "turns": [
        {"text": "Hello", "image": "url", "delay": 0},
        {"text": "Hi there", "delay": 1000}
    ]
}
  2. Batched version
{
    "session_id": "session_123",
    "turns": [
        {"texts": ["Who are you?", "Hello world"], "images": ["/path/1.png", "/path/2.png"]},
        {"texts": ["What is in the image?", "What is AI?"], "images": ["/path/3.png", "/path/4.png"]}
    ]
}
  3. Fixed schedule version
{
    "session_id": "session_123",
    "turns": [
        {"timestamp": 0, "text": "What is deep learning?"},
        {"timestamp": 1000, "text": "Who are you?"}
    ]
}
  4. Time delayed version
{
    "session_id": "session_123",
    "turns": [
        {"delay": 0, "text": "What is deep learning?"},
        {"delay": 1000, "text": "Who are you?"}
    ]
}
  5. Full-featured version (multi-batch, multi-modal, multi-fielded, session-based, etc.)
{
    "session_id": "session_123",
    "turns": [
        {
            "timestamp": 1234,
            "texts": [
                {"name": "text_field_a", "contents": ["hello", "world"]},
                {"name": "text_field_b", "contents": ["hi there"]}
            ],
            "images": [
                {"name": "image_field_a", "contents": ["/path/1.png", "/path/2.png"]},
                {"name": "image_field_b", "contents": ["/path/3.png"]}
            ]
        }
    ]
}
Source code in aiperf/services/dataset/loader/multi_turn.py
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
@CustomDatasetFactory.register(CustomDatasetType.MULTI_TURN)
class MultiTurnDatasetLoader:
    """A dataset loader that reads multi-turn conversations from a JSONL file.

    The multi-turn type
      - supports multi-modal data (e.g. text, image, audio)
      - supports multi-turn features (e.g. delay, sessions, etc.)
      - supports client-side batching for each data (e.g. batch_size > 1)

    NOTE: Entries that share a session ID are grouped together by this
    loader. Any timestamps they carry are sorted in ascending order later
    in the timing manager.

    Examples:
    1. Simple version
    ```json
    {
        "session_id": "session_123",
        "turns": [
            {"text": "Hello", "image": "url", "delay": 0},
            {"text": "Hi there", "delay": 1000}
        ]
    }
    ```

    2. Batched version
    ```json
    {
        "session_id": "session_123",
        "turns": [
            {"texts": ["Who are you?", "Hello world"], "images": ["/path/1.png", "/path/2.png"]},
            {"texts": ["What is in the image?", "What is AI?"], "images": ["/path/3.png", "/path/4.png"]}
        ]
    }
    ```

    3. Fixed schedule version
    ```json
    {
        "session_id": "session_123",
        "turns": [
            {"timestamp": 0, "text": "What is deep learning?"},
            {"timestamp": 1000, "text": "Who are you?"}
        ]
    }
    ```

    4. Time delayed version
    ```json
    {
        "session_id": "session_123",
        "turns": [
            {"delay": 0, "text": "What is deep learning?"},
            {"delay": 1000, "text": "Who are you?"}
        ]
    }
    ```

    5. Full-featured version (multi-batch, multi-modal, multi-fielded, session-based, etc.)
    ```json
    {
        "session_id": "session_123",
        "turns": [
            {
                "timestamp": 1234,
                "texts": [
                    {"name": "text_field_a", "contents": ["hello", "world"]},
                    {"name": "text_field_b", "contents": ["hi there"]}
                ],
                "images": [
                    {"name": "image_field_a", "contents": ["/path/1.png", "/path/2.png"]},
                    {"name": "image_field_b", "contents": ["/path/3.png"]}
                ]
            }
        ]
    }
    ```
    """

    def __init__(self, filename: str):
        self.filename = filename

    def load_dataset(self) -> dict[str, list[MultiTurn]]:
        """Load multi-turn data from a JSONL file.

        Each line represents a complete multi-turn conversation with its own
        session_id and multiple turns.

        Returns:
            A dictionary mapping session_id to list of CustomData (containing the MultiTurn).
        """
        grouped: dict[str, list[MultiTurn]] = defaultdict(list)

        with open(self.filename) as handle:
            for raw in handle:
                record = raw.strip()
                if not record:
                    continue  # skip blank lines

                parsed = MultiTurn.model_validate_json(record)
                # Fall back to a generated ID when the entry has no session_id.
                key = parsed.session_id or str(uuid.uuid4())
                grouped[key].append(parsed)

        return grouped

load_dataset()

Load multi-turn data from a JSONL file.

Each line represents a complete multi-turn conversation with its own session_id and multiple turns.

Returns:

Type Description
dict[str, list[MultiTurn]]

A dictionary mapping session_id to list of CustomData (containing the MultiTurn).

Source code in aiperf/services/dataset/loader/multi_turn.py
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def load_dataset(self) -> dict[str, list[MultiTurn]]:
    """Read one complete multi-turn conversation per non-empty JSONL line.

    Returns:
        A dictionary mapping session_id to list of CustomData (containing the MultiTurn).
    """
    grouped: dict[str, list[MultiTurn]] = defaultdict(list)

    with open(self.filename) as source:
        for raw_line in source:
            text = raw_line.strip()
            if not text:
                continue  # blank lines carry no data

            entry = MultiTurn.model_validate_json(text)
            # Fall back to a generated ID when the entry has no session_id.
            grouped[entry.session_id or str(uuid.uuid4())].append(entry)

    return grouped

aiperf.services.dataset.loader.protocol

aiperf.services.dataset.loader.random_pool

RandomPoolDatasetLoader

A dataset loader that loads data from a single file or a directory.

Each line in the file represents single-turn conversation data, and files create individual pools for random sampling: - Single file: All lines form one single pool (to be randomly sampled from) - Directory: Each file becomes a separate pool, then pools are randomly sampled and merged into conversations later.

The random pool custom dataset - supports multi-modal data (e.g. text, image, audio) - supports client-side batching for each data (e.g. batch size > 1) - supports named fields for each modality (e.g. text_field_a, text_field_b, etc.) - DOES NOT support multi-turn or its features (e.g. delay, sessions, etc.)

Example:

  1. Single file
{"text": "Who are you?", "image": "/path/to/image1.png"}
{"text": "Explain what is the meaning of life.", "image": "/path/to/image2.png"}
...

The file will form a single pool of text and image data that will be used to generate conversations.

  2. Directory

Directory will be useful if user wants to - create multiple pools of different modalities separately (e.g. text, image) - specify different field names for the same modality.

data/queries.jsonl

{"texts": [{"name": "query", "contents": ["Who are you?"]}]}
{"texts": [{"name": "query", "contents": ["What is the meaning of life?"]}]}
...

data/passages.jsonl

{"texts": [{"name": "passage", "contents": ["I am a cat."]}]}
{"texts": [{"name": "passage", "contents": ["I am a dog."]}]}
...

The loader will create two separate pools for each file: queries and passages. Each pool is a text dataset with a different field name (e.g. query, passage), and loader will later sample from these two pools to create conversations.

Source code in aiperf/services/dataset/loader/random_pool.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
@CustomDatasetFactory.register(CustomDatasetType.RANDOM_POOL)
class RandomPoolDatasetLoader:
    """A dataset loader that builds sampling pools from a file or a directory.

    Each line in a file represents single-turn conversation data, and every
    file becomes one pool for random sampling:
      - Single file: all lines form one single pool (to be randomly sampled from)
      - Directory: each file becomes a separate pool, then pools are randomly
                   sampled and merged into conversations later.

    The random pool custom dataset
      - supports multi-modal data (e.g. text, image, audio)
      - supports client-side batching for each data (e.g. batch size > 1)
      - supports named fields for each modality (e.g. text_field_a, text_field_b, etc.)
      - DOES NOT support multi-turn or its features (e.g. delay, sessions, etc.)

    Example:

    1. Single file
    ```jsonl
    {"text": "Who are you?", "image": "/path/to/image1.png"}
    {"text": "Explain what is the meaning of life.", "image": "/path/to/image2.png"}
    ...
    ```
    The file forms a single pool of text and image data that is used to
    generate conversations.

    2. Directory

    A directory is useful when the user wants to
      - create multiple pools of different modalities separately (e.g. text, image)
      - specify different field names for the same modality.

    data/queries.jsonl
    ```jsonl
    {"texts": [{"name": "query", "contents": ["Who are you?"]}]}
    {"texts": [{"name": "query", "contents": ["What is the meaning of life?"]}]}
    ...
    ```

    data/passages.jsonl
    ```jsonl
    {"texts": [{"name": "passage", "contents": ["I am a cat."]}]}
    {"texts": [{"name": "passage", "contents": ["I am a dog."]}]}
    ...
    ```

    Two separate pools are created, one per file (queries and passages).
    Each pool is a text dataset with its own field name (e.g. query,
    passage); the loader later samples from both pools to create
    conversations.
    """

    def __init__(self, filename: str):
        self.filename = filename

    def load_dataset(self) -> dict[Filename, list[RandomPool]]:
        """Load random pool data from a file or directory.

        If filename points at a file, it is parsed with the RandomPool model.
        If it points at a directory, every file inside is read and items with
        different modality names are merged into combined RandomPool objects.

        Returns:
            A dictionary mapping filename to list of RandomPool objects.
        """
        root = Path(self.filename)

        if root.is_file():
            return {root.name: self._load_dataset_from_file(root)}

        return self._load_dataset_from_dir(root)

    def _load_dataset_from_file(self, file_path: Path) -> list[RandomPool]:
        """Parse one JSONL file into a pool of RandomPool entries.

        Args:
            file_path: The path to the file containing the data.

        Returns:
            A list of RandomPool objects.
        """
        pool: list[RandomPool] = []

        with open(file_path) as handle:
            for raw in handle:
                record = raw.strip()
                if not record:
                    continue  # skip blank lines

                pool.append(RandomPool.model_validate_json(record))

        return pool

    def _load_dataset_from_dir(
        self, dir_path: Path
    ) -> dict[Filename, list[RandomPool]]:
        """Build one pool per regular file found in the directory.

        Args:
            dir_path: The path to the directory containing the files.

        Returns:
            A dictionary mapping filename to list of RandomPool objects.
        """
        pools: dict[Filename, list[RandomPool]] = defaultdict(list)

        for entry in dir_path.iterdir():
            if not entry.is_file():
                continue  # ignore anything that is not a regular file
            pools[entry.name].extend(self._load_dataset_from_file(entry))

        return pools

load_dataset()

Load random pool data from a file or directory.

If filename is a file, reads and parses using the RandomPool model. If filename is a directory, reads each file in the directory and merges items with different modality names into combined RandomPool objects.

Returns:

Type Description
dict[Filename, list[RandomPool]]

A dictionary mapping filename to list of RandomPool objects.

Source code in aiperf/services/dataset/loader/random_pool.py
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
def load_dataset(self) -> dict[Filename, list[RandomPool]]:
    """Load random pool data from a file or directory.

    If filename is a file, reads and parses using the RandomPool model.
    If filename is a directory, reads each file in the directory and merges
    items with different modality names into combined RandomPool objects.

    Returns:
        A dictionary mapping filename to list of RandomPool objects.
    """
    path = Path(self.filename)

    if path.is_file():
        # Single-file case: key the result by the file's own name.
        dataset_pool = self._load_dataset_from_file(path)
        return {path.name: dataset_pool}

    # Directory case: delegate to the per-directory loader.
    return self._load_dataset_from_dir(path)

aiperf.services.dataset.loader.single_turn

SingleTurnDatasetLoader

A dataset loader that loads single turn data from a file.

The single turn type - supports multi-modal data (e.g. text, image, audio) - supports client-side batching for each data (e.g. batch_size > 1) - DOES NOT support multi-turn features (e.g. delay, sessions, etc.)

Examples: 1. Single-batch, text only

{"text": "What is deep learning?"}
  1. Single-batch, multi-modal
{"text": "What is in the image?", "image": "/path/to/image.png"}
  1. Multi-batch, multi-modal
{"texts": ["Who are you?", "Hello world"], "images": ["/path/to/image.png", "/path/to/image2.png"]}
  1. Fixed schedule version
{"timestamp": 0, "text": "What is deep learning?"},
{"timestamp": 1000, "text": "Who are you?"},
{"timestamp": 2000, "text": "What is AI?"}
  1. Time delayed version
{"delay": 0, "text": "What is deep learning?"},
{"delay": 1234, "text": "Who are you?"}
  1. Full-featured version (Multi-batch, multi-modal, multi-fielded)
{
    "texts": [
        {"name": "text_field_A", "contents": ["Hello", "World"]},
        {"name": "text_field_B", "contents": ["Hi there"]}
    ],
    "images": [
        {"name": "image_field_A", "contents": ["/path/1.png", "/path/2.png"]},
        {"name": "image_field_B", "contents": ["/path/3.png"]}
    ]
}
Source code in aiperf/services/dataset/loader/single_turn.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
@CustomDatasetFactory.register(CustomDatasetType.SINGLE_TURN)
class SingleTurnDatasetLoader:
    """A dataset loader that loads single turn data from a file.

    The single turn type
      - supports multi-modal data (e.g. text, image, audio)
      - supports client-side batching for each data (e.g. batch_size > 1)
      - DOES NOT support multi-turn features (e.g. delay, sessions, etc.)

    Examples:
    1. Single-batch, text only
    ```json
    {"text": "What is deep learning?"}
    ```

    2. Single-batch, multi-modal
    ```json
    {"text": "What is in the image?", "image": "/path/to/image.png"}
    ```

    3. Multi-batch, multi-modal
    ```json
    {"texts": ["Who are you?", "Hello world"], "images": ["/path/to/image.png", "/path/to/image2.png"]}
    ```

    4. Fixed schedule version
    ```json
    {"timestamp": 0, "text": "What is deep learning?"},
    {"timestamp": 1000, "text": "Who are you?"},
    {"timestamp": 2000, "text": "What is AI?"}
    ```

    5. Time delayed version
    ```json
    {"delay": 0, "text": "What is deep learning?"},
    {"delay": 1234, "text": "Who are you?"}
    ```

    6. Full-featured version (Multi-batch, multi-modal, multi-fielded)
    ```json
    {
        "texts": [
            {"name": "text_field_A", "contents": ["Hello", "World"]},
            {"name": "text_field_B", "contents": ["Hi there"]}
        ],
        "images": [
            {"name": "image_field_A", "contents": ["/path/1.png", "/path/2.png"]},
            {"name": "image_field_B", "contents": ["/path/3.png"]}
        ]
    }
    ```
    """

    def __init__(self, filename: str):
        # Path to the JSONL file containing one single-turn entry per line.
        self.filename = filename

    def load_dataset(self) -> dict[str, list[SingleTurn]]:
        """Load single-turn data from a JSONL file.

        Each non-empty line is parsed as one single-turn conversation and
        stored under a freshly generated UUID session_id, so every session
        holds exactly one turn.

        Returns:
            A dictionary mapping session_id to list of SingleTurn objects.
        """
        data: dict[str, list[SingleTurn]] = defaultdict(list)

        with open(self.filename) as f:
            for line in f:
                if (line := line.strip()) == "":
                    continue  # Skip empty lines

                single_turn_data = SingleTurn.model_validate_json(line)
                # Each line gets its own session_id; turns are never grouped.
                session_id = str(uuid.uuid4())
                data[session_id].append(single_turn_data)

        return data

load_dataset()

Load single-turn data from a JSONL file.

Each line represents a single turn conversation. Every line is assigned a freshly generated UUID as its session_id, so each session contains exactly one turn.

Returns:

Type Description
dict[str, list[SingleTurn]]

A dictionary mapping session_id to list of CustomData.

Source code in aiperf/services/dataset/loader/single_turn.py
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
def load_dataset(self) -> dict[str, list[SingleTurn]]:
    """Load single-turn data from a JSONL file.

    Each non-empty line is parsed as one single-turn conversation and
    stored under a freshly generated UUID session_id, so every session
    holds exactly one turn.

    Returns:
        A dictionary mapping session_id to list of SingleTurn objects.
    """
    data: dict[str, list[SingleTurn]] = defaultdict(list)

    with open(self.filename) as f:
        for line in f:
            if (line := line.strip()) == "":
                continue  # Skip empty lines

            single_turn_data = SingleTurn.model_validate_json(line)
            # Each line gets its own session_id; turns are never grouped.
            session_id = str(uuid.uuid4())
            data[session_id].append(single_turn_data)

    return data

aiperf.services.dataset.utils

check_file_exists(filename)

Verifies that the file exists.

Parameters:

Name Type Description Default
filename

The file path to verify.

required

Raises:

Type Description
FileNotFoundError

If the file does not exist.

Source code in aiperf/services/dataset/utils.py
18
19
20
21
22
23
24
25
26
27
28
def check_file_exists(filename: Path) -> None:
    """Verifies that the file exists.

    Args:
        filename: The file path to verify.

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    if not filename.exists():
        # Fix: interpolate the offending path (the placeholder had been lost),
        # so the error message is actionable.
        raise FileNotFoundError(f"The file '{filename}' does not exist.")

encode_image(img, format)

Encodes an image into base64 encoded string.

Parameters:

Name Type Description Default
img Image

The PIL Image object to encode.

required
format str

The image format to use (e.g., "JPEG", "PNG").

required

Returns:

Type Description
str

A base64 encoded string representation of the image.

Source code in aiperf/services/dataset/utils.py
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def encode_image(img: Image, format: str) -> str:
    """Encodes an image into base64 encoded string.

    Args:
        img: The PIL Image object to encode.
        format: Target image format (e.g., "JPEG", "PNG").

    Returns:
        The base64-encoded bytes of the saved image, as a UTF-8 string.
    """
    # JPEG cannot represent palette (P) or alpha (RGBA) modes, which PNGs
    # commonly use, so force an RGB conversion first in that case.
    needs_rgb = format == "JPEG" and img.mode != "RGB"
    target = img.convert("RGB") if needs_rgb else img

    buffer = BytesIO()
    target.save(buffer, format=format)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")

load_json_str(json_str, func=lambda x: x)

Deserializes JSON encoded string into Python object.

Parameters:

Name Type Description Default
json_str str

JSON encoded string

required
func Callable

A function that takes deserialized JSON object. This can be used to run validation checks on the object. Defaults to identity function.

lambda x: x

Returns:

Type Description
dict[str, Any]

The deserialized JSON object.

Raises:

Type Description
RuntimeError

If the JSON string is invalid.

Source code in aiperf/services/dataset/utils.py
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
def load_json_str(json_str: str, func: Callable = lambda x: x) -> dict[str, Any]:
    """Deserializes JSON encoded string into Python object.

    Args:
        json_str: JSON encoded string.
        func: Callback applied to the deserialized object; useful for running
            validation checks. Defaults to the identity function.

    Returns:
        The deserialized (and possibly transformed) JSON object.

    Raises:
        RuntimeError: If the JSON string is invalid.
    """
    try:
        # TODO: Consider using orjson for faster JSON parsing
        return func(json.loads(json_str))
    except json.JSONDecodeError as err:
        # Truncate very long payloads so the error stays readable.
        preview = json_str[:200] + ("..." if len(json_str) > 200 else "")
        raise RuntimeError(f"Failed to parse JSON string: '{preview}'") from err

open_image(filename)

Opens an image file.

Parameters:

Name Type Description Default
filename

The file path to open.

required

Returns:

Type Description
Image

The opened PIL Image object.

Raises:

Type Description
FileNotFoundError

If the file does not exist.

Source code in aiperf/services/dataset/utils.py
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def open_image(filename: str) -> Image:
    """Opens an image file and validates its format.

    Args:
        filename: The file path to open.

    Returns:
        The opened PIL Image object.

    Raises:
        FileNotFoundError: If the file does not exist.
        RuntimeError: If the image format cannot be determined or is not one
            of the supported ImageFormat values.
    """
    check_file_exists(Path(filename))
    img = Image.open(filename)

    if img.format is None:
        # Fix: interpolate the actual filename into the error message
        # (the placeholder had been lost).
        raise RuntimeError(f"Failed to determine image format of '{filename}'.")

    if img.format.upper() not in list(ImageFormat):
        raise RuntimeError(
            f"'{img.format}' is not one of the supported image formats: "
            f"{', '.join(ImageFormat)}"
        )
    return img

sample_normal(mean, stddev, lower=-np.inf, upper=np.inf)

Sample from a normal distribution with support for bounds using rejection sampling.

Parameters:

Name Type Description Default
mean float

The mean of the normal distribution.

required
stddev float

The standard deviation of the normal distribution.

required
lower float

The lower bound of the distribution.

-inf
upper float

The upper bound of the distribution.

inf

Returns:

Type Description
int

An integer sampled from the distribution.

Source code in aiperf/services/dataset/utils.py
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
def sample_normal(
    mean: float, stddev: float, lower: float = -np.inf, upper: float = np.inf
) -> float:
    """Sample from a normal distribution with support for bounds using rejection sampling.

    Args:
        mean: The mean of the normal distribution.
        stddev: The standard deviation of the normal distribution.
        lower: The lower bound of the distribution (inclusive).
        upper: The upper bound of the distribution (inclusive).

    Returns:
        A float sampled from the (truncated) distribution.

    Note:
        The return annotation was previously ``int``, but the function has
        always returned the raw float sample; callers such as
        ``sample_positive_normal_integer`` round it themselves.
    """
    # Rejection sampling: redraw until the sample falls within the bounds.
    # Loops forever if [lower, upper] has (near-)zero probability mass.
    while True:
        n = np.random.normal(mean, stddev)
        if lower <= n <= upper:
            return n

sample_positive_normal(mean, stddev)

Sample from a normal distribution ensuring positive values without distorting the distribution.

Parameters:

Name Type Description Default
mean float

Mean value for the normal distribution

required
stddev float

Standard deviation for the normal distribution

required

Returns:

Type Description
float

A positive sample from the normal distribution

Raises:

Type Description
ValueError

If mean is less than 0

Source code in aiperf/services/dataset/utils.py
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
def sample_positive_normal(mean: float, stddev: float) -> float:
    """Draw a non-negative sample from a normal distribution.

    Delegates to ``sample_normal`` with a lower bound of 0, which uses
    rejection sampling and therefore keeps the distribution shape intact
    instead of clipping negative draws.

    Args:
        mean: Mean value for the normal distribution.
        stddev: Standard deviation for the normal distribution.

    Returns:
        A positive sample from the normal distribution.

    Raises:
        ValueError: If mean is less than 0.
    """
    # Guard clause: a negative mean would make rejection sampling wasteful
    # and is treated as a caller error.
    if mean < 0:
        raise ValueError(f"Mean value ({mean}) should be greater than 0")

    return sample_normal(mean, stddev, lower=0)

sample_positive_normal_integer(mean, stddev)

Sample a random positive integer from a normal distribution.

Parameters:

Name Type Description Default
mean float

The mean of the normal distribution.

required
stddev float

The standard deviation of the normal distribution.

required

Returns:

Type Description
int

A positive integer sampled from the distribution. If the sampled

int

number is less than 1, it returns 1.

Source code in aiperf/services/dataset/utils.py
138
139
140
141
142
143
144
145
146
147
148
149
def sample_positive_normal_integer(mean: float, stddev: float) -> int:
    """Sample a random positive integer from a normal distribution.

    The underlying float sample is rounded up, so any draw in (0, 1]
    maps to 1 — the result is always at least 1.

    Args:
        mean: The mean of the normal distribution.
        stddev: The standard deviation of the normal distribution.

    Returns:
        A positive integer sampled from the distribution.
    """
    sampled = sample_positive_normal(mean, stddev)
    return math.ceil(sampled)

aiperf.services.inference_result_parser.inference_result_parser

InferenceResultParser

Bases: BaseComponentService

InferenceResultParser is responsible for parsing the inference results and pushing them to the RecordsManager.

Source code in aiperf/services/inference_result_parser/inference_result_parser.py
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
@ServiceFactory.register(ServiceType.INFERENCE_RESULT_PARSER)
class InferenceResultParser(BaseComponentService):
    """InferenceResultParser is responsible for parsing the inference results
    and pushing them to the RecordsManager.
    """

    def __init__(
        self,
        service_config: ServiceConfig,
        user_config: UserConfig,
        service_id: str | None = None,
    ) -> None:
        """Create the parser and its messaging clients.

        Args:
            service_config: Service-level configuration.
            user_config: User configuration (tokenizer settings, model endpoint).
            service_id: Optional explicit service identifier.
        """
        super().__init__(
            service_config=service_config,
            user_config=user_config,
            service_id=service_id,
        )
        self.debug("Initializing inference result parser")
        # Pull client: receives raw inference results from the proxy backend.
        self.inference_results_client: PullClientProtocol = (
            self.comms.create_pull_client(
                CommunicationClientAddressType.RAW_INFERENCE_PROXY_BACKEND,
            )
        )
        # Push client: sends parsed records downstream.
        self.records_push_client: PushClientProtocol = self.comms.create_push_client(
            CommunicationClientAddressType.RECORDS,
        )
        # Request client: fetches conversation turns from the dataset manager,
        # used when computing input token counts.
        self.conversation_request_client: RequestClientProtocol = (
            self.comms.create_request_client(
                CommunicationClientAddressType.DATASET_MANAGER_PROXY_FRONTEND,
            )
        )
        # Tokenizer cache keyed by model name; all access is guarded by
        # tokenizer_lock (populated eagerly in _initialize, lazily in
        # get_tokenizer).
        self.tokenizers: dict[str, Tokenizer] = {}
        self.user_config: UserConfig = user_config
        self.tokenizer_lock: asyncio.Lock = asyncio.Lock()
        self.model_endpoint: ModelEndpointInfo = ModelEndpointInfo.from_user_config(
            user_config
        )

    @property
    def service_type(self) -> ServiceType:
        """The type of service."""
        return ServiceType.INFERENCE_RESULT_PARSER

    @on_init
    async def _initialize(self) -> None:
        """Initialize inference result parser-specific components."""
        self.debug("Initializing inference result parser")

        # Route incoming inference-results messages to the handler below.
        await self.inference_results_client.register_pull_callback(
            message_type=MessageType.INFERENCE_RESULTS,
            callback=self._on_inference_results,
            # TODO: Support for unbounded concurrency in the future by setting to None or 0?
            max_concurrency=1_000_000,
        )

        # The extractor implementation is selected by the endpoint type.
        self.extractor = ResponseExtractorFactory.create_instance(
            self.model_endpoint.endpoint.type,
            model_endpoint=self.model_endpoint,
        )

        # Eagerly build one tokenizer per configured model so the hot path
        # rarely needs to create one on demand.
        async with self.tokenizer_lock:
            self.tokenizers = {
                model.name: Tokenizer.from_pretrained(
                    self.user_config.tokenizer.name or model.name,
                    trust_remote_code=self.user_config.tokenizer.trust_remote_code,
                    revision=self.user_config.tokenizer.revision,
                )
                for model in self.model_endpoint.models.models
            }
            self.info("Initialized tokenizers for %d models", len(self.tokenizers))

    async def get_tokenizer(self, model: str) -> Tokenizer:
        """Get the tokenizer for a given model, creating and caching it if missing."""
        async with self.tokenizer_lock:
            if model not in self.tokenizers:
                # An explicitly configured tokenizer name takes precedence
                # over the model name.
                self.tokenizers[model] = Tokenizer.from_pretrained(
                    self.user_config.tokenizer.name or model,
                    trust_remote_code=self.user_config.tokenizer.trust_remote_code,
                    revision=self.user_config.tokenizer.revision,
                )
            return self.tokenizers[model]

    @on_configure
    async def _configure(self, message: CommandMessage) -> None:
        """Configure the inference result parser."""
        # No configuration is currently required.

    async def _on_inference_results(self, message: InferenceResultsMessage) -> None:
        """Handle an inference results message.

        Three outcomes, each pushed to the records client:
          - errored record: forwarded with empty responses
          - valid record: fully parsed (falls back to empty responses if
            parsing raises)
          - otherwise: tagged with an InvalidInferenceResults error
        """
        self.debug(lambda: f"Received inference results message: {message}")

        if message.record.has_error:
            # Errored requests are forwarded as-is with no parsed responses.
            await self.records_push_client.push(
                ParsedInferenceResultsMessage(
                    service_id=self.service_id,
                    record=ParsedResponseRecord(
                        worker_id=message.service_id,
                        request=message.record,
                        responses=[],
                    ),
                )
            )

        elif message.record.valid:
            try:
                record = await self.process_valid_record(message)
                self.debug(
                    lambda: f"Received {len(record.request.responses)} responses, input_token_count: {record.input_token_count}, output_token_count: {record.output_token_count}"
                )
                await self.records_push_client.push(
                    ParsedInferenceResultsMessage(
                        service_id=self.service_id,
                        record=record,
                    )
                )
            except Exception as e:
                # Parsing failed: still push the record so it is accounted
                # for, but with no parsed responses.
                self.exception(f"Error processing valid record: {e}")
                await self.records_push_client.push(
                    ParsedInferenceResultsMessage(
                        service_id=self.service_id,
                        record=ParsedResponseRecord(
                            worker_id=message.service_id,
                            request=message.record,
                            responses=[],
                        ),
                    )
                )
        else:
            # Neither errored nor valid: mark the record with an explicit
            # error before forwarding it.
            self.warning(f"Received invalid inference results: {message.record}")
            message.record.error = ErrorDetails(
                code=None,
                message="Invalid inference results",
                type="InvalidInferenceResults",
            )
            await self.records_push_client.push(
                ParsedInferenceResultsMessage(
                    service_id=self.service_id,
                    record=ParsedResponseRecord(
                        worker_id=message.service_id,
                        request=message.record,
                        responses=[],
                    ),
                )
            )

    async def process_valid_record(
        self, message: InferenceResultsMessage
    ) -> ParsedResponseRecord:
        """Process a valid request record.

        Extracts per-response data and computes input/output token counts.
        Records without a model name are returned with None token counts.
        """
        if message.record.model_name is None:
            self.warning(
                lambda: f"Model name is None, unable to process record: {message.record}"
            )
            return ParsedResponseRecord(
                worker_id=message.service_id,
                request=message.record,
                responses=[],
                input_token_count=None,
                output_token_count=None,
            )

        tokenizer = await self.get_tokenizer(message.record.model_name)
        resp = await self.extractor.extract_response_data(message.record, tokenizer)
        input_token_count = await self.compute_input_token_count(
            message.record, tokenizer
        )
        # Responses without a token count are excluded from the total.
        output_token_count = sum(
            response.token_count
            for response in resp
            if response.token_count is not None
        )

        return ParsedResponseRecord(
            worker_id=message.service_id,
            request=message.record,
            responses=resp,
            input_token_count=input_token_count,
            output_token_count=output_token_count,
        )

    async def compute_input_token_count(
        self, record: RequestRecord, tokenizer: Tokenizer
    ) -> int | None:
        """Compute the number of tokens in the input for a given request record.

        Fetches the conversation turn from the dataset manager and tokenizes
        its text contents. Returns None when the turn cannot be identified
        or retrieved.
        """
        # Without both identifiers the turn cannot be looked up.
        if record.conversation_id is None or record.turn_index is None:
            self.warning(
                lambda: f"Conversation ID or turn index is None: {record.conversation_id=} {record.turn_index=}"
            )
            return None

        turn_response: ConversationTurnResponseMessage = (
            await self.conversation_request_client.request(
                ConversationTurnRequestMessage(
                    service_id=self.service_id,
                    conversation_id=record.conversation_id,
                    turn_index=record.turn_index,
                )
            )
        )
        if isinstance(turn_response, ErrorMessage):
            self.error(lambda: f"Error getting turn response: {turn_response}")
            return None

        turn = turn_response.turn
        # Only text contents contribute to the input token count.
        return sum(
            len(tokenizer.encode(content))
            for text in turn.texts
            for content in text.contents
        )

service_type property

The type of service.

compute_input_token_count(record, tokenizer) async

Compute the number of tokens in the input for a given request record.

Source code in aiperf/services/inference_result_parser/inference_result_parser.py
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
async def compute_input_token_count(
    self, record: RequestRecord, tokenizer: Tokenizer
) -> int | None:
    """Compute the number of tokens in the input for a given request record."""
    # Without both identifiers the turn cannot be looked up.
    if record.conversation_id is None or record.turn_index is None:
        self.warning(
            lambda: f"Conversation ID or turn index is None: {record.conversation_id=} {record.turn_index=}"
        )
        return None

    turn_response: ConversationTurnResponseMessage = (
        await self.conversation_request_client.request(
            ConversationTurnRequestMessage(
                service_id=self.service_id,
                conversation_id=record.conversation_id,
                turn_index=record.turn_index,
            )
        )
    )
    if isinstance(turn_response, ErrorMessage):
        self.error(lambda: f"Error getting turn response: {turn_response}")
        return None

    turn = turn_response.turn
    # Only text contents contribute to the input token count.
    return sum(
        len(tokenizer.encode(content))
        for text in turn.texts
        for content in text.contents
    )

get_tokenizer(model) async

Get the tokenizer for a given model.

Source code in aiperf/services/inference_result_parser/inference_result_parser.py
109
110
111
112
113
114
115
116
117
118
async def get_tokenizer(self, model: str) -> Tokenizer:
    """Get the tokenizer for a given model, creating and caching it if missing."""
    async with self.tokenizer_lock:
        if model not in self.tokenizers:
            # An explicitly configured tokenizer name takes precedence over
            # the model name.
            self.tokenizers[model] = Tokenizer.from_pretrained(
                self.user_config.tokenizer.name or model,
                trust_remote_code=self.user_config.tokenizer.trust_remote_code,
                revision=self.user_config.tokenizer.revision,
            )
        return self.tokenizers[model]

process_valid_record(message) async

Process a valid request record.

Source code in aiperf/services/inference_result_parser/inference_result_parser.py
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
async def process_valid_record(
    self, message: InferenceResultsMessage
) -> ParsedResponseRecord:
    """Process a valid request record.

    Extracts per-response data and computes input/output token counts.
    Records without a model name are returned with None token counts.
    """
    if message.record.model_name is None:
        self.warning(
            lambda: f"Model name is None, unable to process record: {message.record}"
        )
        return ParsedResponseRecord(
            worker_id=message.service_id,
            request=message.record,
            responses=[],
            input_token_count=None,
            output_token_count=None,
        )

    tokenizer = await self.get_tokenizer(message.record.model_name)
    resp = await self.extractor.extract_response_data(message.record, tokenizer)
    input_token_count = await self.compute_input_token_count(
        message.record, tokenizer
    )
    # Responses without a token count are excluded from the total.
    output_token_count = sum(
        response.token_count
        for response in resp
        if response.token_count is not None
    )

    return ParsedResponseRecord(
        worker_id=message.service_id,
        request=message.record,
        responses=resp,
        input_token_count=input_token_count,
        output_token_count=output_token_count,
    )

aiperf.services.inference_result_parser.openai_parsers

OpenAIObject

Bases: CaseInsensitiveStrEnum

Types of OpenAI objects.

Source code in aiperf/services/inference_result_parser/openai_parsers.py
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
class OpenAIObject(CaseInsensitiveStrEnum):
    """Types of OpenAI objects."""

    CHAT_COMPLETION = "chat.completion"
    CHAT_COMPLETION_CHUNK = "chat.completion.chunk"
    COMPLETION = "completion"
    EMBEDDING = "embedding"
    RESPONSE = "response"

    @classmethod
    def parse(cls, text: str) -> BaseModel:
        """Attempt to parse a string into an OpenAI object.

        Args:
            text: A JSON-encoded OpenAI API payload with an "object" field.

        Returns:
            The Pydantic model corresponding to the payload's "object" type.

        Raises:
            ValueError: If the object is invalid.
        """
        try:
            obj = load_json_str(text)
        except (json.JSONDecodeError, RuntimeError) as e:
            # Fix: load_json_str wraps JSONDecodeError in a RuntimeError, so
            # catching JSONDecodeError alone was dead code and malformed input
            # escaped as RuntimeError instead of the documented ValueError.
            raise ValueError(f"Invalid OpenAI object: {text}") from e

        # Mapping of OpenAI object types to their corresponding Pydantic models.
        _object_mapping: dict[str, type[BaseModel]] = {
            cls.CHAT_COMPLETION: ChatCompletion,
            cls.CHAT_COMPLETION_CHUNK: ChatCompletionChunk,
            cls.COMPLETION: Completion,
            cls.EMBEDDING: Embedding,
            cls.RESPONSE: ResponsesModel,
        }

        obj_type = obj.get("object")
        if obj_type is None:
            raise ValueError(f"Invalid OpenAI object: {obj}")

        if obj_type not in _object_mapping:
            raise ValueError(f"Invalid OpenAI object type: {obj_type}")

        try:
            return _object_mapping[obj_type](**obj)
        except Exception as e:
            raise ValueError(f"Invalid OpenAI object: {text}") from e

parse(text) classmethod

Attempt to parse a string into an OpenAI object.

Raises:

Type Description
ValueError

If the object is invalid.

Source code in aiperf/services/inference_result_parser/openai_parsers.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
@classmethod
def parse(cls, text: str) -> BaseModel:
    """Attempt to parse a string into an OpenAI object.

    Raises:
        ValueError: If the object is invalid.
    """
    try:
        obj = load_json_str(text)
    # NOTE(review): load_json_str wraps JSONDecodeError in a RuntimeError, so
    # this except clause likely never fires and malformed input escapes as
    # RuntimeError — confirm and consider catching RuntimeError here.
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid OpenAI object: {text}") from e

    # Mapping of OpenAI object types to their corresponding Pydantic models.
    _object_mapping: dict[str, type[BaseModel]] = {
        cls.CHAT_COMPLETION: ChatCompletion,
        cls.CHAT_COMPLETION_CHUNK: ChatCompletionChunk,
        cls.COMPLETION: Completion,
        cls.EMBEDDING: Embedding,
        cls.RESPONSE: ResponsesModel,
    }

    obj_type = obj.get("object")
    if obj_type is None:
        raise ValueError(f"Invalid OpenAI object: {obj}")

    if obj_type not in _object_mapping:
        raise ValueError(f"Invalid OpenAI object type: {obj_type}")

    try:
        return _object_mapping[obj_type](**obj)
    except Exception as e:
        raise ValueError(f"Invalid OpenAI object: {text}") from e

OpenAIResponseExtractor

Extractor for OpenAI responses.

Source code in aiperf/services/inference_result_parser/openai_parsers.py
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
@ResponseExtractorFactory.register_all(
    EndpointType.OPENAI_CHAT_COMPLETIONS,
    EndpointType.OPENAI_COMPLETIONS,
    EndpointType.OPENAI_RESPONSES,
)
class OpenAIResponseExtractor:
    """Extractor for OpenAI responses."""

    def __init__(self, model_endpoint: ModelEndpointInfo) -> None:
        """Create a new response extractor based on the provided configuration."""
        self.model_endpoint = model_endpoint

    def _parse_text_response(self, response: TextResponse) -> ResponseData | None:
        """Parse a TextResponse into a ResponseData object."""
        text = response.text
        content = self._parse_text(text)
        if content is None:
            return None

        return ResponseData(
            perf_ns=response.perf_ns,
            raw_text=[text],
            parsed_text=[content],
            metadata={},
        )

    def _parse_sse_response(self, response: SSEMessage) -> ResponseData | None:
        """Parse a SSEMessage into a ResponseData object."""
        data_lines = response.extract_data_content()
        contents = self._parse_sse(data_lines)
        if not contents:
            return None

        return ResponseData(
            perf_ns=response.perf_ns,
            raw_text=data_lines,
            parsed_text=contents,
            metadata={},
        )

    def _parse_response(self, response: InferenceServerResponse) -> ResponseData | None:
        """Parse a response into a ResponseData object."""
        if isinstance(response, TextResponse):
            return self._parse_text_response(response)
        if isinstance(response, SSEMessage):
            return self._parse_sse_response(response)
        # Unknown response types yield no data.
        return None

    async def extract_response_data(
        self, record: RequestRecord, tokenizer: Tokenizer | None
    ) -> list[ResponseData]:
        """Extract the text from a server response message."""
        extracted: list[ResponseData] = []
        for response in record.responses:
            data = self._parse_response(response)
            if data is None:
                continue

            if tokenizer is not None:
                # Sum token counts over every parsed text chunk of this response.
                data.token_count = sum(
                    len(tokenizer.encode(text))
                    for text in data.parsed_text
                    if text is not None
                )
            extracted.append(data)
        return extracted

    def _parse_text(self, raw_text: str) -> Any | None:
        """Parse the text of the response."""
        if raw_text in ("", None, "[DONE]"):
            return None

        obj = OpenAIObject.parse(raw_text)

        # TODO: how to support multiple choices?
        if isinstance(obj, ChatCompletion):
            return obj.choices[0].message.content
        if isinstance(obj, ChatCompletionChunk):
            return obj.choices[0].delta.content
        if isinstance(obj, Completion):
            return obj.choices[0].text
        if isinstance(obj, Embedding):
            return obj.embedding
        if isinstance(obj, ResponsesModel):
            return obj.output_text

        raise ValueError(f"Invalid OpenAI object: {raw_text}")

    def _parse_sse(self, raw_sse: list[str]) -> list[Any]:
        """Parse the SSE of the response."""
        return [
            content
            for line in raw_sse
            if (content := self._parse_text(line)) is not None
        ]

__init__(model_endpoint)

Create a new response extractor based on the provided configuration.

Source code in aiperf/services/inference_result_parser/openai_parsers.py
83
84
85
def __init__(self, model_endpoint: ModelEndpointInfo) -> None:
    """Create a new response extractor based on the provided configuration.

    Args:
        model_endpoint: Endpoint configuration; stored on the instance.
    """
    self.model_endpoint = model_endpoint

extract_response_data(record, tokenizer) async

Extract the text from a server response message.

Source code in aiperf/services/inference_result_parser/openai_parsers.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
async def extract_response_data(
    self, record: RequestRecord, tokenizer: Tokenizer | None
) -> list[ResponseData]:
    """Extract the text from a server response message."""
    extracted: list[ResponseData] = []
    for response in record.responses:
        data = self._parse_response(response)
        if data is None:
            continue

        if tokenizer is not None:
            # Sum token counts over every parsed text chunk of this response.
            data.token_count = sum(
                len(tokenizer.encode(text))
                for text in data.parsed_text
                if text is not None
            )
        extracted.append(data)
    return extracted

aiperf.services.records_manager.metrics.base_metric

BaseMetric

Bases: ABC

Base class for all metrics with automatic subclass registration.

Source code in aiperf/services/records_manager/metrics/base_metric.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
class BaseMetric(ABC):
    """Base class for all metrics with automatic subclass registration.

    Concrete subclasses are auto-registered in ``metric_interfaces``, keyed
    by their ``tag`` class attribute (see ``__init_subclass__``).
    """

    # Class attributes that subclasses must override
    tag: ClassVar[str] = ""
    unit: ClassVar[MetricTimeType] = MetricTimeType.NANOSECONDS
    larger_is_better: ClassVar[bool] = True
    header: ClassVar[str] = ""
    streaming_only: ClassVar[bool] = False
    base_metrics: set[str] = set()
    # Tags of other metrics this metric depends on for its calculation.
    # Declared with an empty default so _check_metrics() works even for
    # subclasses that have no dependencies and do not override it.
    required_metrics: ClassVar[set[str]] = set()

    metric_interfaces: dict[str, type["BaseMetric"]] = {}

    def __init_subclass__(cls, **kwargs):
        """
        This method is called when a class is subclassed from Metric.
        It automatically registers the subclass in the metric_interfaces
        dictionary using the `tag` class attribute.
        The `tag` attribute must be a non-empty string that uniquely identifies the
        metric type. Only concrete (non-abstract) classes will be registered.

        Raises:
            TypeError: If a concrete subclass has a missing, empty, or
                non-string `tag`.
            ValueError: If the `tag` is already registered by another class.
        """

        super().__init_subclass__(**kwargs)

        # Only register concrete classes (not abstract ones)
        if inspect.isabstract(cls):
            return

        # Enforce that subclasses define a non-empty tag
        if not cls.tag or not isinstance(cls.tag, str):
            raise TypeError(
                f"Concrete metric class {cls.__name__} must define a non-empty 'tag' class attribute"
            )

        # Check for duplicate tags
        if cls.tag in cls.metric_interfaces:
            raise ValueError(
                f"Metric tag '{cls.tag}' is already registered by {cls.metric_interfaces[cls.tag].__name__}"
            )

        cls.metric_interfaces[cls.tag] = cls

    @classmethod
    def get_all(cls) -> dict[str, type["BaseMetric"]]:
        """
        Returns the dictionary of all registered metric interfaces.

        This method dynamically imports all metric type modules from the 'types'
        directory to ensure all metric classes are registered via __init_subclass__.

        Returns:
            dict[str, type[Metric]]: Mapping of metric tags to their corresponding classes

        Raises:
            MetricTypeError: If there's an error importing metric type modules
        """
        # Get the types directory path
        types_dir = Path(__file__).parent / "types"

        # Import all metric type modules to trigger registration
        if types_dir.exists():
            for python_file in types_dir.glob("*.py"):
                if python_file.name != "__init__.py":
                    module_name = python_file.stem  # Get filename without extension
                    try:
                        # Importing runs __init_subclass__, which registers the metric.
                        importlib.import_module(
                            f"aiperf.services.records_manager.metrics.types.{module_name}"
                        )
                    except ImportError as err:
                        raise MetricTypeError(
                            f"Error importing metric type module '{module_name}'"
                        ) from err

        return cls.metric_interfaces

    @abstractmethod
    def update_value(
        self,
        record: ParsedResponseRecord | None = None,
        # Fixed annotation: dict requires both key and value types.
        metrics: dict[str, "BaseMetric"] | None = None,
    ) -> None:
        """
        Updates the metric value based on the provided record and dictionary of other metrics.

        Args:
            record (Optional[ParsedResponseRecord]): The record to update the metric with.
            metrics (Optional[dict[str, BaseMetric]]): A dictionary of other metrics that may be needed for calculation.
        """

    @abstractmethod
    def values(self) -> Any:
        """
        Returns the list of calculated metrics.
        """

    @abstractmethod
    def _check_record(self, record: ParsedResponseRecord) -> None:
        """
        Checks if the record is valid for metric calculation.

        Raises:
            ValueError: If the record does not meet the required conditions.
        """

    def get_converted_metrics(self, unit: MetricTimeType) -> list[Any]:
        """
        Returns values() converted from this metric's time unit to `unit`.

        Raises:
            MetricTypeError: If `unit` is not a MetricTimeType, or if this
                metric is unit-less (its `unit` is None) and cannot be converted.
        """
        if not isinstance(unit, MetricTimeType):
            raise MetricTypeError("Invalid metric time type for conversion.")
        if self.unit is None:
            # Unit-less metrics (e.g. token counts) have no time unit to convert
            # from; fail with a clear error instead of a KeyError downstream.
            raise MetricTypeError("Metric has no time unit to convert from.")

        scale_factor = self._get_conversion_factor(self.unit, unit)

        return [metric / 10**scale_factor for metric in self.values()]

    def _check_metrics(self, metrics: dict[str, "BaseMetric"]) -> None:
        """
        Validates that the required dependent metrics are available.

        Raises:
            ValueError: If required metrics are missing.
        """
        if not metrics:
            raise ValueError("Metrics dictionary is missing.")

        for tag in self.required_metrics:
            if tag not in metrics:
                raise ValueError(f"Missing required metric: '{tag}'")

    def _get_conversion_factor(
        self, from_unit: MetricTimeType, to_unit: MetricTimeType
    ) -> int:
        """Return the power-of-ten exponent to divide by when converting units."""
        unit_scales = {
            MetricTimeType.NANOSECONDS: 9,
            MetricTimeType.MILLISECONDS: 3,
            MetricTimeType.SECONDS: 0,
        }

        return unit_scales[from_unit] - unit_scales[to_unit]

    def _require_valid_record(self, record: ParsedResponseRecord) -> None:
        """
        Ensures the given record is not None and is marked as valid.

        Raises:
            ValueError: If the record is None or invalid.
        """
        if not record or not record.valid:
            raise ValueError("Invalid Record")

__init_subclass__(**kwargs)

This method is called when a class is subclassed from Metric. It automatically registers the subclass in the metric_interfaces dictionary using the tag class attribute. The tag attribute must be a non-empty string that uniquely identifies the metric type. Only concrete (non-abstract) classes will be registered.

Source code in aiperf/services/records_manager/metrics/base_metric.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def __init_subclass__(cls, **kwargs):
    """
    This method is called when a class is subclassed from Metric.
    It automatically registers the subclass in the metric_interfaces
    dictionary using the `tag` class attribute.
    The `tag` attribute must be a non-empty string that uniquely identifies the
    metric type. Only concrete (non-abstract) classes will be registered.

    Raises:
        TypeError: If a concrete subclass has a missing, empty, or
            non-string `tag`.
        ValueError: If the `tag` is already registered by another class.
    """

    super().__init_subclass__(**kwargs)

    # Only register concrete classes (not abstract ones)
    if inspect.isabstract(cls):
        return

    # Enforce that subclasses define a non-empty tag
    if not cls.tag or not isinstance(cls.tag, str):
        raise TypeError(
            f"Concrete metric class {cls.__name__} must define a non-empty 'tag' class attribute"
        )

    # Check for duplicate tags
    if cls.tag in cls.metric_interfaces:
        raise ValueError(
            f"Metric tag '{cls.tag}' is already registered by {cls.metric_interfaces[cls.tag].__name__}"
        )

    cls.metric_interfaces[cls.tag] = cls

get_all() classmethod

Returns the dictionary of all registered metric interfaces.

This method dynamically imports all metric type modules from the 'types' directory to ensure all metric classes are registered via `__init_subclass__`.

Returns:

Type Description
dict[str, type[BaseMetric]]

dict[str, type[Metric]]: Mapping of metric tags to their corresponding classes

Raises:

Type Description
MetricTypeError

If there's an error importing metric type modules

Source code in aiperf/services/records_manager/metrics/base_metric.py
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
@classmethod
def get_all(cls) -> dict[str, type["BaseMetric"]]:
    """
    Returns the dictionary of all registered metric interfaces.

    This method dynamically imports all metric type modules from the 'types'
    directory to ensure all metric classes are registered via __init_subclass__.

    Returns:
        dict[str, type[Metric]]: Mapping of metric tags to their corresponding classes

    Raises:
        MetricTypeError: If there's an error importing metric type modules
    """
    # Get the types directory path
    types_dir = Path(__file__).parent / "types"

    # Import all metric type modules to trigger registration
    if types_dir.exists():
        for python_file in types_dir.glob("*.py"):
            if python_file.name != "__init__.py":
                module_name = python_file.stem  # Get filename without extension
                try:
                    # Importing runs __init_subclass__, which registers the metric.
                    importlib.import_module(
                        f"aiperf.services.records_manager.metrics.types.{module_name}"
                    )
                except ImportError as err:
                    raise MetricTypeError(
                        f"Error importing metric type module '{module_name}'"
                    ) from err

    return cls.metric_interfaces

update_value(record=None, metrics=None) abstractmethod

Updates the metric value based on the provided record and dictionary of other metrics.

Parameters:

Name Type Description Default
record Optional[Record]

The record to update the metric with.

None
metrics Optional[dict[BaseMetric]]

A dictionary of other metrics that may be needed for calculation.

None
Source code in aiperf/services/records_manager/metrics/base_metric.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
@abstractmethod
def update_value(
    self,
    record: ParsedResponseRecord | None = None,
    # Fixed annotation: dict requires both key and value types.
    metrics: dict[str, "BaseMetric"] | None = None,
) -> None:
    """
    Updates the metric value based on the provided record and dictionary of other metrics.

    Args:
        record (Optional[ParsedResponseRecord]): The record to update the metric with.
        metrics (Optional[dict[str, BaseMetric]]): A dictionary of other metrics that may be needed for calculation.
    """

values() abstractmethod

Returns the list of calculated metrics.

Source code in aiperf/services/records_manager/metrics/base_metric.py
103
104
105
106
107
@abstractmethod
def values(self) -> Any:
    """
    Returns the list of calculated metrics.

    Concrete metrics return either a single number (e.g. a running
    min/max/count) or a list of per-record values.
    """

aiperf.services.records_manager.metrics.types.benchmark_duration_metric

BenchmarkDurationMetric

Bases: BaseMetric

Post-processor for calculating the Benchmark Duration metric.

Source code in aiperf/services/records_manager/metrics/types/benchmark_duration_metric.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class BenchmarkDurationMetric(BaseMetric):
    """
    Post-processor for calculating the Benchmark Duration metric.

    Computed as the span between the earliest request timestamp and the
    latest response timestamp (both tracked in nanoseconds).
    """

    tag = "benchmark_duration"
    unit = MetricTimeType.NANOSECONDS
    larger_is_better = False
    header = "Benchmark Duration"
    type = MetricType.METRIC_OF_METRICS
    required_metrics: set[str] = {MinRequestMetric.tag, MaxResponseMetric.tag}

    def __init__(self):
        # Single scalar duration; overwritten on every update_value() call.
        self.metric: float = 0.0

    def update_value(
        self,
        record: ParsedResponseRecord | None = None,
        # Fixed annotation: dict requires both key and value types.
        metrics: dict[str, "BaseMetric"] | None = None,
    ) -> None:
        """
        Recomputes the duration from the min-request and max-response metrics.

        Raises:
            ValueError: If the required metrics are missing.
        """
        self._check_metrics(metrics)
        min_req_time = metrics[MinRequestMetric.tag].values()
        max_res_time = metrics[MaxResponseMetric.tag].values()
        self.metric = max_res_time - min_req_time

    def values(self) -> float:
        """
        Returns the BenchmarkDuration metric.
        """
        return self.metric

    def _check_record(self, record: ParsedResponseRecord) -> None:
        # Metric-of-metrics: individual records are not validated here.
        pass

values()

Returns the BenchmarkDuration metric.

Source code in aiperf/services/records_manager/metrics/types/benchmark_duration_metric.py
40
41
42
43
44
def values(self) -> float:
    """
    Returns the BenchmarkDuration metric.

    The value is the elapsed time between the earliest request and the
    latest response across all records.
    """
    return self.metric

aiperf.services.records_manager.metrics.types.input_sequence_length_metric

InputSequenceLengthMetric

Bases: BaseMetric

Post-processor for calculating Input Sequence Length (ISL) metrics from records.

Source code in aiperf/services/records_manager/metrics/types/input_sequence_length_metric.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class InputSequenceLengthMetric(BaseMetric):
    """
    Post-processor for calculating Input Sequence Length (ISL) metrics from records.
    """

    tag = "isl"
    unit = None
    larger_is_better = False
    header = "Input Sequence Length"
    type = MetricType.METRIC_OF_RECORDS
    streaming_only = False
    required_metrics: set[str] = set()

    def __init__(self):
        # One input token count per processed record, in update order.
        self.metric: list[int] = []

    def update_value(
        self,
        record: ParsedResponseRecord | None = None,
        metrics: dict[str, "BaseMetric"] | None = None,
    ):
        self._check_record(record)
        self.metric.append(record.input_token_count)

    def values(self) -> list[int]:
        """Return the list of Input Sequence Length (ISL) metrics."""
        return self.metric

    def _check_record(self, record: ParsedResponseRecord):
        """
        Checks if the record is valid for ISL calculation.

        Raises:
            ValueError: If the record is not valid or doesn't have input_token_count.
        """
        self._require_valid_record(record)
        if record.input_token_count is None:
            raise ValueError("Input Token Count is not available for the record.")

values()

Returns the list of Input Sequence Length (ISL) metrics.

Source code in aiperf/services/records_manager/metrics/types/input_sequence_length_metric.py
34
35
36
37
38
def values(self) -> list[int]:
    """
    Returns the list of Input Sequence Length (ISL) metrics.

    One input token count per processed record, in update order.
    """
    return self.metric

aiperf.services.records_manager.metrics.types.inter_token_latency_metric

InterTokenLatencyMetric

Bases: BaseMetric

Post Processor for calculating Inter Token Latency Metric from records.

Source code in aiperf/services/records_manager/metrics/types/inter_token_latency_metric.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
class InterTokenLatencyMetric(BaseMetric):
    """
    Post Processor for calculating Inter Token Latency Metric from records.
    """

    tag = "inter_token_latency"
    unit = MetricTimeType.MILLISECONDS
    larger_is_better = False
    header = "Inter Token Latency (ITL)"
    type = MetricType.METRIC_OF_BOTH
    streaming_only = True
    required_metrics: set[str] = {RequestLatencyMetric.tag, TTFTMetric.tag}

    def __init__(self):
        # One ITL value per processed record.
        self.metric: list[float] = []
        # Index into the dependent metrics' value lists for the next record.
        self._current_index: int = 0

    def update_value(
        self,
        record: ParsedResponseRecord | None = None,
        metrics: dict[str, "BaseMetric"] | None = None,
    ):
        self._check_record(record)
        self._check_metrics(metrics)
        idx = self._current_index
        latency = metrics[RequestLatencyMetric.tag].values()[idx]
        first_token_time = metrics[TTFTMetric.tag].values()[idx]
        token_count = record.output_token_count
        # Average gap between consecutive tokens, excluding the first token.
        self.metric.append((latency - first_token_time) / (token_count - 1))
        self._current_index += 1

    def values(self) -> list[float]:
        """Return the list of Inter Token Latency (ITL) metrics."""
        return self.metric

    def _check_record(self, record):
        self._require_valid_record(record)
        if record.output_token_count is None:
            raise ValueError("Output token count is not available for the record.")
        if record.output_token_count <= 1:
            raise ValueError(
                "Output token count must be greater than 1 for ITL calculation."
            )

values()

Returns the list of Inter Token Latency (ITL) metrics.

Source code in aiperf/services/records_manager/metrics/types/inter_token_latency_metric.py
46
47
48
49
50
def values(self) -> list[float]:
    """
    Returns the list of Inter Token Latency (ITL) metrics.

    One ITL value per processed record, in update order.
    """
    return self.metric

aiperf.services.records_manager.metrics.types.max_response_metric

MaxResponseMetric

Bases: BaseMetric

Post-processor for calculating the maximum response time stamp metric from records.

Source code in aiperf/services/records_manager/metrics/types/max_response_metric.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class MaxResponseMetric(BaseMetric):
    """
    Post-processor for calculating the maximum response time stamp metric from records.
    """

    tag = "max_response"
    unit = MetricTimeType.NANOSECONDS
    type = MetricType.METRIC_OF_RECORDS
    larger_is_better = False
    header = "Maximum Response Timestamp"
    required_metrics: set[str] = set()

    def __init__(self):
        # Running maximum of final-response timestamps; 0 until first record.
        self.metric: float = 0

    def update_value(
        self,
        record: ParsedResponseRecord | None = None,
        # Fixed annotation: dict requires both key and value types.
        metrics: dict[str, "BaseMetric"] | None = None,
    ) -> None:
        """
        Adds a new record and calculates the maximum response timestamp metric.

        Raises:
            ValueError: If the record is None or invalid.
        """
        self._check_record(record)
        # NOTE(review): assumes record.responses is non-empty for valid records
        # — confirm upstream validation guarantees this.
        self.metric = max(self.metric, record.responses[-1].perf_ns)

    def values(self) -> float:
        """
        Returns the Max Response Timestamp metric.
        """
        return self.metric

    def _check_record(self, record: ParsedResponseRecord) -> None:
        """
        Checks if the record is valid for calculations.
        """
        self._require_valid_record(record)

update_value(record=None, metrics=None)

Adds a new record and calculates the maximum response timestamp metric.

Source code in aiperf/services/records_manager/metrics/types/max_response_metric.py
23
24
25
26
27
28
29
30
31
32
33
34
def update_value(
    self,
    record: ParsedResponseRecord | None = None,
    # Fixed annotation: dict requires both key and value types.
    metrics: dict[str, "BaseMetric"] | None = None,
) -> None:
    """
    Adds a new record and calculates the maximum response timestamp metric.

    Raises:
        ValueError: If the record is None or invalid.
    """
    self._check_record(record)
    if record.responses[-1].perf_ns > self.metric:
        self.metric = record.responses[-1].perf_ns

values()

Returns the Max Response Timestamp metric.

Source code in aiperf/services/records_manager/metrics/types/max_response_metric.py
36
37
38
39
40
def values(self) -> float:
    """
    Returns the Max Response Timestamp metric.

    The running maximum final-response perf_ns seen so far (0 before any
    record has been processed).
    """
    return self.metric

aiperf.services.records_manager.metrics.types.min_request_metric

MinRequestMetric

Bases: BaseMetric

Post-processor for calculating the minimum request time stamp metric from records.

Source code in aiperf/services/records_manager/metrics/types/min_request_metric.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class MinRequestMetric(BaseMetric):
    """
    Post-processor for calculating the minimum request time stamp metric from records.
    """

    tag = "min_request"
    unit = MetricTimeType.NANOSECONDS
    type = MetricType.METRIC_OF_RECORDS
    larger_is_better = False
    header = "Minimum Request Timestamp"
    required_metrics: set[str] = set()

    def __init__(self):
        # Running minimum of record start timestamps; inf until first record.
        self.metric: float = float("inf")

    def update_value(
        self,
        record: ParsedResponseRecord | None = None,
        # Fixed annotation: dict requires both key and value types.
        metrics: dict[str, "BaseMetric"] | None = None,
    ) -> None:
        """
        Adds a new record and calculates the minimum request timestamp metric.

        Raises:
            ValueError: If the record is None or invalid.
        """
        self._check_record(record)
        self.metric = min(self.metric, record.start_perf_ns)

    def values(self) -> float:
        """
        Returns the Minimum Request Timestamp metric.
        """
        return self.metric

    def _check_record(self, record: ParsedResponseRecord) -> None:
        """
        Checks if the record is valid for calculations.
        """
        self._require_valid_record(record)

update_value(record=None, metrics=None)

Adds a new record and calculates the minimum request timestamp metric.

Source code in aiperf/services/records_manager/metrics/types/min_request_metric.py
23
24
25
26
27
28
29
30
31
32
33
34
def update_value(
    self,
    record: ParsedResponseRecord | None = None,
    # Fixed annotation: dict requires both key and value types.
    metrics: dict[str, "BaseMetric"] | None = None,
) -> None:
    """
    Adds a new record and calculates the minimum request timestamp metric.

    Raises:
        ValueError: If the record is None or invalid.
    """
    self._check_record(record)
    if record.start_perf_ns < self.metric:
        self.metric = record.start_perf_ns

values()

Returns the Minimum Request Timestamp metric.

Source code in aiperf/services/records_manager/metrics/types/min_request_metric.py
36
37
38
39
40
def values(self) -> float:
    """
    Returns the Minimum Request Timestamp metric.

    The running minimum start_perf_ns seen so far (inf before any record
    has been processed).
    """
    return self.metric

aiperf.services.records_manager.metrics.types.output_sequence_length_metric

OutputSequenceLengthMetric

Bases: BaseMetric

Post-processor for calculating Output Sequence Length (OSL) metrics from records.

Source code in aiperf/services/records_manager/metrics/types/output_sequence_length_metric.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
class OutputSequenceLengthMetric(BaseMetric):
    """
    Post-processor for calculating Output Sequence Length (OSL) metrics from records.
    """

    tag = "osl"
    unit = None
    larger_is_better = False
    header = "Output Sequence Length"
    type = MetricType.METRIC_OF_RECORDS
    streaming_only = False
    required_metrics: set[str] = set()

    def __init__(self):
        # One output token count per processed record, in update order.
        self.metric: list[int] = []

    def update_value(
        self,
        record: ParsedResponseRecord | None = None,
        metrics: dict[str, "BaseMetric"] | None = None,
    ):
        self._check_record(record)
        token_count = record.output_token_count
        self.metric.append(token_count)

    def values(self):
        """Return the list of Output Sequence Length (OSL) metrics."""
        return self.metric

    def _check_record(self, record: ParsedResponseRecord):
        """
        Checks if the record is valid for OSL calculation.

        Raises:
            ValueError: If record is not valid or output_token_count is missing.
        """
        self._require_valid_record(record)
        if record.output_token_count is None:
            raise ValueError("Output token count is missing in the record.")

values()

Returns the list of Output Sequence Length (OSL) metrics.

Source code in aiperf/services/records_manager/metrics/types/output_sequence_length_metric.py
33
34
35
36
37
def values(self):
    """
    Returns the list of Output Sequence Length (OSL) metrics.

    One output token count per processed record, in update order.
    """
    return self.metric

aiperf.services.records_manager.metrics.types.output_token_throughput_per_user_metric

OutputTokenThroughputPerUserMetric

Bases: BaseMetric

Post Processor for calculating Output Token Throughput per user metrics from records.

Source code in aiperf/services/records_manager/metrics/types/output_token_throughput_per_user_metric.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class OutputTokenThroughputPerUserMetric(BaseMetric):
    """
    Post Processor for calculating Output Token Throughput per user metrics from records.
    """

    tag = "output_token_throughput_per_user"
    unit = MetricTimeType.SECONDS
    larger_is_better = True
    header = "Output Token Throughput Per User"
    type = MetricType.METRIC_OF_METRICS
    streaming_only = True
    required_metrics: set[str] = {InterTokenLatencyMetric.tag}

    def __init__(self):
        # One throughput value (tokens/sec) per inter-token-latency value.
        self.metric: list[float] = []

    def update_value(
        self,
        record: ParsedResponseRecord | None = None,
        metrics: dict[str, "BaseMetric"] | None = None,
    ):
        """
        Derives per-user throughput as the reciprocal of each inter-token latency.

        Raises:
            ValueError: If required metrics are missing or a latency is <= 0.
        """
        self._check_metrics(metrics)
        # Use the tag constant rather than a duplicated string literal so this
        # lookup stays in sync with required_metrics above.
        inter_token_latencies = metrics[InterTokenLatencyMetric.tag].values()
        # NOTE(review): repeated calls append duplicates — confirm this is
        # invoked only once per benchmark run.
        for inter_token_latency in inter_token_latencies:
            inter_token_latency_s = inter_token_latency / NANOS_PER_SECOND
            if inter_token_latency_s <= 0:
                raise ValueError("Inter-token latency must be greater than 0.")
            self.metric.append(1 / inter_token_latency_s)

    def values(self):
        """
        Returns the list of Output Token Throughput Per User metrics.
        """
        return self.metric

    def _check_record(self, record):
        # Metric-of-metrics: individual records are not validated here.
        pass

values()

Returns the list of Output Token Throughput Per User metrics.

Source code in aiperf/services/records_manager/metrics/types/output_token_throughput_per_user_metric.py
42
43
44
45
46
def values(self):
    """
    Returns the list of Output Token Throughput Per User metrics.

    Each value is the reciprocal of an inter-token latency (tokens/sec).
    """
    return self.metric

aiperf.services.records_manager.metrics.types.request_count_metric

RequestCountMetric

Bases: BaseMetric

Post-processor for counting the number of valid requests.

Source code in aiperf/services/records_manager/metrics/types/request_count_metric.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class RequestCountMetric(BaseMetric):
    """
    Post-processor for counting the number of valid requests.
    """

    tag = "request_count"
    unit = None
    larger_is_better = True
    header = "Request Count"
    type = MetricType.METRIC_OF_RECORDS
    streaming_only = False
    required_metrics: set[str] = set()

    def __init__(self):
        # Number of valid records observed so far.
        self.metric: int = 0

    def update_value(
        self,
        record: ParsedResponseRecord | None = None,
        metrics: dict[str, "BaseMetric"] | None = None,
    ) -> None:
        self._check_record(record)
        self.metric = self.metric + 1

    def values(self) -> int:
        """Return the Request Count metric."""
        return self.metric

    def _check_record(self, record: ParsedResponseRecord) -> None:
        self._require_valid_record(record)

values()

Returns the Request Count metric.

Source code in aiperf/services/records_manager/metrics/types/request_count_metric.py
33
34
35
36
37
def values(self) -> int:
    """
    Returns the Request Count metric.
    """
    return self.metric

aiperf.services.records_manager.metrics.types.request_latency_metric

RequestLatencyMetric

Bases: BaseMetric

Post-processor for calculating Request Latency metrics from records.

Source code in aiperf/services/records_manager/metrics/types/request_latency_metric.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class RequestLatencyMetric(BaseMetric):
    """
    Post-processor for calculating Request Latency metrics from records.

    Request latency is the elapsed time, in nanoseconds, between a record's
    request start timestamp and its final response timestamp.
    """

    tag = "request_latency"
    unit = MetricTimeType.NANOSECONDS
    # METRIC_OF_RECORDS: computed from each record independently; no metric deps.
    type = MetricType.METRIC_OF_RECORDS
    larger_is_better = False
    header = "Request Latency"
    required_metrics: set[str] = set()

    def __init__(self):
        # One latency sample (nanoseconds) per processed record.
        self.metric: list[int] = []

    def update_value(
        self,
        record: ParsedResponseRecord | None = None,
        metrics: dict[str, "BaseMetric"] | None = None,
    ) -> None:
        """
        Adds a new record and calculates the Request Latency metric.

        This method extracts the request and last response timestamps, calculates the differences in time, and
        appends the result to the metric list.

        The ``metrics`` argument is unused here; it exists to satisfy the
        common ``update_value`` signature shared by all metric types.
        """
        self._check_record(record)
        request_ts = record.start_perf_ns
        final_response_ts = record.responses[-1].perf_ns
        request_latency = final_response_ts - request_ts
        self.metric.append(request_latency)

    def values(self) -> list[int]:
        """
        Returns the list of Request Latency metrics.
        """
        return self.metric

    def _check_record(self, record: ParsedResponseRecord) -> None:
        # Raises if the record is None or invalid (delegates to the base helper).
        self._require_valid_record(record)

update_value(record=None, metrics=None)

Adds a new record and calculates the Request Latency metric.

This method extracts the request and last response timestamps, calculates the differences in time, and appends the result to the metric list.

Source code in aiperf/services/records_manager/metrics/types/request_latency_metric.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def update_value(
    self,
    record: ParsedResponseRecord | None = None,
    metrics: dict["BaseMetric"] | None = None,
) -> None:
    """
    Adds a new record and calculates the Request Latency metric.

    This method extracts the request and last response timestamps, calculates the differences in time, and
    appends the result to the metric list.
    """
    self._check_record(record)
    request_ts = record.start_perf_ns
    final_response_ts = record.responses[-1].perf_ns
    request_latency = final_response_ts - request_ts
    self.metric.append(request_latency)

values()

Returns the list of Request Latency metrics.

Source code in aiperf/services/records_manager/metrics/types/request_latency_metric.py
40
41
42
43
44
def values(self) -> list[int]:
    """
    Returns the list of Request Latency metrics.
    """
    return self.metric

aiperf.services.records_manager.metrics.types.request_throughput_metric

RequestThroughputMetric

Bases: BaseMetric

Post Processor for calculating Request throughput metrics from records.

Source code in aiperf/services/records_manager/metrics/types/request_throughput_metric.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
class RequestThroughputMetric(BaseMetric):
    """
    Post Processor for calculating Request throughput metrics from records.
    """

    tag = "request_throughput"
    # NOTE(review): the unit is a time type even though the value is a rate
    # (requests per second) — presumably the reporting layer interprets this;
    # confirm against the metric-display code.
    unit = MetricTimeType.SECONDS
    larger_is_better = True
    header = "Request Throughput"
    # METRIC_OF_METRICS: derived purely from other metrics, never from records.
    type = MetricType.METRIC_OF_METRICS
    streaming_only = False
    # Must be computed after request_count and benchmark_duration are available.
    required_metrics: set[str] = {RequestCountMetric.tag, BenchmarkDurationMetric.tag}

    def __init__(self):
        # NOTE(review): never written after __init__ (update_value uses a
        # local of the same name); appears unused.
        self.total_requests: int = 0
        # Requests per second over the whole benchmark run.
        self.metric: float = 0.0

    def update_value(
        self,
        record: ParsedResponseRecord | None = None,
        metrics: dict[str, "BaseMetric"] | None = None,
    ) -> None:
        """Compute throughput = total requests / benchmark duration in seconds.

        Assumes the benchmark duration is reported in nanoseconds and is
        non-zero; a zero duration would raise ZeroDivisionError. The
        ``record`` argument is unused for this METRIC_OF_METRICS type.
        """
        self._check_metrics(metrics)
        total_requests = metrics[RequestCountMetric.tag].values()
        benchmark_duration = metrics[BenchmarkDurationMetric.tag].values()
        self.metric = total_requests / (benchmark_duration / NANOS_PER_SECOND)

    def values(self) -> float:
        """
        Returns the Request Throughput metric.
        """
        return self.metric

    def _check_record(self, record: ParsedResponseRecord) -> None:
        """
        Checks if the record is valid.

        Raises:
            ValueError: If the record is None or is invalid.
        """
        self._require_valid_record(record)

values()

Returns the Request Throughput metric.

Source code in aiperf/services/records_manager/metrics/types/request_throughput_metric.py
42
43
44
45
46
def values(self) -> float:
    """
    Returns the Request Throughput metric.
    """
    return self.metric

aiperf.services.records_manager.metrics.types.ttft_metric

TTFTMetric

Bases: BaseMetric

Post-processor for calculating Time to First Token (TTFT) metrics from records.

Source code in aiperf/services/records_manager/metrics/types/ttft_metric.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class TTFTMetric(BaseMetric):
    """
    Post-processor for calculating Time to First Token (TTFT) metrics from records.

    TTFT is the elapsed time, in nanoseconds, between a record's request start
    timestamp and its first response timestamp.
    """

    tag = "ttft"
    unit = MetricTimeType.NANOSECONDS
    larger_is_better = False
    header = "Time to First Token (TTFT)"
    # METRIC_OF_RECORDS: computed from each record independently; no metric deps.
    type = MetricType.METRIC_OF_RECORDS
    streaming_only = True
    required_metrics: set[str] = set()

    def __init__(self):
        # One TTFT sample (nanoseconds) per processed record.
        self.metric: list[int] = []

    def update_value(
        self,
        record: ParsedResponseRecord | None = None,
        metrics: dict[str, "BaseMetric"] | None = None,
    ) -> None:
        """
        Adds a new record and calculates the Time To First Token (TTFT) metric.

        This method extracts the timestamp from the request and the first response in the given
        RequestRecord object, computes the difference (TTFT), and appends the result to the metric list.

        The ``metrics`` argument is unused here; it exists to satisfy the
        common ``update_value`` signature shared by all metric types.
        """
        self._check_record(record)
        # NOTE(review): reads record.request.start_perf_ns while
        # RequestLatencyMetric reads record.start_perf_ns — confirm both
        # attributes exist on ParsedResponseRecord and agree.
        request_ts = record.request.start_perf_ns
        response_ts = record.responses[0].perf_ns
        ttft = response_ts - request_ts
        self.metric.append(ttft)

    def values(self) -> list[int]:
        """
        Returns the list of Time to First Token (TTFT) metrics.
        """
        return self.metric

    def _check_record(self, record: ParsedResponseRecord) -> None:
        """
        Checks if the record is valid for TTFT calculation.

        Raises:
            ValueError: If record is None or record is not valid
        """
        self._require_valid_record(record)

update_value(record=None, metrics=None)

Adds a new record and calculates the Time To First Token (TTFT) metric.

This method extracts the timestamp from the request and the first response in the given RequestRecord object, computes the difference (TTFT), and appends the result to the metric list.

Source code in aiperf/services/records_manager/metrics/types/ttft_metric.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def update_value(
    self,
    record: ParsedResponseRecord | None = None,
    metrics: dict["BaseMetric"] | None = None,
) -> None:
    """
    Adds a new record and calculates the Time To First Token (TTFT) metric.

    This method extracts the timestamp from the request and the first response in the given
    RequestRecord object, computes the difference (TTFT), and appends the result to the metric list.
    """
    self._check_record(record)
    request_ts = record.request.start_perf_ns
    response_ts = record.responses[0].perf_ns
    ttft = response_ts - request_ts
    self.metric.append(ttft)

values()

Returns the list of Time to First Token (TTFT) metrics.

Source code in aiperf/services/records_manager/metrics/types/ttft_metric.py
41
42
43
44
45
def values(self) -> list[int]:
    """
    Returns the list of Time to First Token (TTFT) metrics.
    """
    return self.metric

aiperf.services.records_manager.metrics.types.ttst_metric

TTSTMetric

Bases: BaseMetric

Post-processor for calculating Time to Second Token (TTST) metrics from records.

Source code in aiperf/services/records_manager/metrics/types/ttst_metric.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
class TTSTMetric(BaseMetric):
    """
    Post-processor for calculating Time to Second Token (TTST) metrics from records.

    TTST is the elapsed time, in nanoseconds, between a record's first and
    second response timestamps.
    """

    tag = "ttst"
    unit = MetricTimeType.NANOSECONDS
    larger_is_better = False
    header = "Time to Second Token (TTST)"
    # METRIC_OF_RECORDS: computed from each record independently; no metric deps.
    type = MetricType.METRIC_OF_RECORDS
    streaming_only = True
    required_metrics: set[str] = set()

    def __init__(self):
        # One TTST sample (nanoseconds) per processed record.
        self.metric: list[int] = []

    def update_value(
        self,
        record: ParsedResponseRecord | None = None,
        metrics: dict[str, "BaseMetric"] | None = None,
    ) -> None:
        """
        Adds a new record and calculates the Time To Second Token (TTST) metric.

        This method extracts the timestamp from the first and second response in the given
        Record object, computes the difference (TTST), and appends the result to the metric list.

        The ``metrics`` argument is unused here; it exists to satisfy the
        common ``update_value`` signature shared by all metric types.
        """
        self._check_record(record)
        first_response_ts = record.responses[0].perf_ns
        second_response_ts = record.responses[1].perf_ns
        ttst = second_response_ts - first_response_ts
        self.metric.append(ttst)

    def values(self) -> list[int]:
        """
        Returns the list of Time to Second Token (TTST) metrics.
        """
        return self.metric

    def _check_record(self, record: ParsedResponseRecord) -> None:
        """
        Checks if the record is valid for TTST calculation.

        Raises:
            ValueError: If the record does not have at least two responses,
                or if the response timestamps are out of order.
        """
        self._require_valid_record(record)
        if len(record.responses) < 2:
            raise ValueError(
                "Record must have at least two responses to calculate TTST."
            )
        if record.responses[1].perf_ns < record.responses[0].perf_ns:
            raise ValueError(
                "Second response timestamp must be greater than or equal to the first response timestamp."
            )

update_value(record=None, metrics=None)

Adds a new record and calculates the Time To Second Token (TTST) metric.

This method extracts the timestamp from the first and second response in the given Record object, computes the difference (TTST), and appends the result to the metric list.

Source code in aiperf/services/records_manager/metrics/types/ttst_metric.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def update_value(
    self,
    record: ParsedResponseRecord | None = None,
    metrics: dict["BaseMetric"] | None = None,
) -> None:
    """
    Adds a new record and calculates the Time To Second Token (TTST) metric.

    This method extracts the timestamp from the first and second response in the given
    Record object, computes the difference (TTST), and appends the result to the metric list.
    """
    self._check_record(record)
    first_reponse_ts = record.responses[0].perf_ns
    second_response_ts = record.responses[1].perf_ns
    ttst = second_response_ts - first_reponse_ts
    self.metric.append(ttst)

values()

Returns the list of Time to Second Token (TTST) metrics.

Source code in aiperf/services/records_manager/metrics/types/ttst_metric.py
41
42
43
44
45
def values(self) -> list[int]:
    """
    Returns the list of Time to First Token (TTST) metrics.
    """
    return self.metric

aiperf.services.records_manager.post_processors.metric_summary

MetricSummary

MetricSummary is a post-processor that generates a summary of metrics from the records. It processes the records to extract relevant metrics and returns them in a structured format.

Source code in aiperf/services/records_manager/post_processors/metric_summary.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
@PostProcessorFactory.register(PostProcessorType.METRIC_SUMMARY)
class MetricSummary:
    """
    MetricSummary is a post-processor that generates a summary of metrics from the records.
    It processes the records to extract relevant metrics and returns them in a structured format.
    """

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self.logger.debug("Initializing MetricSummary post-processor")

        # Instantiate one instance of every registered metric type.
        self._metrics = []
        for metric_cls in BaseMetric.get_all().values():
            self._metrics.append(metric_cls())

    def process(self, records: list[ParsedResponseRecord]) -> None:
        """
        Classifies and computes metrics in dependency order to ensure correctness.
        The metrics are categorized based on their dependency types:

        1. METRIC_OF_RECORDS:
            - Depend solely on each individual record.
            - Computed first, as they have no dependencies.

        2. METRIC_OF_BOTH:
            - Depend on both:
                - the current record, and
                - previously computed metrics (specifically, METRIC_OF_RECORDS).
            - Computed after all METRIC_OF_RECORDS have been processed.
            - Must not depend on other METRIC_OF_BOTH or METRIC_OF_METRICS.

        3. METRIC_OF_METRICS:
            - Computed based only on other metrics (not records).
            - May depend on any combination of:
                - METRIC_OF_RECORDS,
                - METRIC_OF_BOTH,
                - other METRIC_OF_METRICS (if dependency order is respected).
            - Computed using a dependency-resolution loop.

        This process ensures:
            - All metrics are computed exactly once, after dependencies are satisfied.
            - Misconfigured or cyclic dependencies will raise an explicit runtime error.

        Raises:
            ValueError: If METRIC_OF_METRICS dependencies are circular or
                unsatisfiable.
        """
        # The tag -> metric mapping is invariant across all iterations below,
        # so build it once instead of once per record / per resolution pass.
        metrics_by_tag = {m.tag: m for m in self._metrics}

        # METRIC_OF_RECORDS
        for record in records:
            for metric in self._metrics:
                if metric.type == MetricType.METRIC_OF_RECORDS:
                    metric.update_value(record=record)

        # METRIC_OF_BOTH
        for record in records:
            for metric in self._metrics:
                if metric.type == MetricType.METRIC_OF_BOTH:
                    metric.update_value(record=record, metrics=metrics_by_tag)

        # METRIC_OF_METRICS
        # Precompute tags of all metrics already processed
        computed_tags = {
            m.tag
            for m in self._metrics
            if m.type in {MetricType.METRIC_OF_RECORDS, MetricType.METRIC_OF_BOTH}
        }

        remaining = [m for m in self._metrics if m.type == MetricType.METRIC_OF_METRICS]

        # Resolve dependencies: loop until all metrics are computed or a circular dependency is found
        while remaining:
            progress = False
            # Iterate over a copy so we can remove computed metrics in place.
            for metric in remaining[:]:
                # If required dependencies are all satisfied, compute this metric
                if metric.required_metrics.issubset(computed_tags):
                    metric.update_value(metrics=metrics_by_tag)
                    computed_tags.add(metric.tag)
                    remaining.remove(metric)
                    progress = True

            if not progress:
                # Circular dependencies
                missing = {m.tag: m.required_metrics - computed_tags for m in remaining}
                raise ValueError(
                    f"Circular or unsatisfiable dependencies detected in METRIC_OF_METRICS: {missing}"
                )

    def get_metrics_summary(self) -> list[MetricResult]:
        """Build one MetricResult summary per metric from the computed values."""
        metrics_summary = []

        # Scalar-valued metrics are broadcast by pandas alongside list-valued ones.
        df = pd.DataFrame({metric.tag: metric.values() for metric in self._metrics})

        for metric in self._metrics:
            res: MetricResult = record_from_dataframe(df, metric)
            metrics_summary.append(res)
        return metrics_summary

process(records)

Classifies and computes metrics in dependency order to ensure correctness. The metrics are categorized based on their dependency types:

  1. METRIC_OF_RECORDS:

    • Depend solely on each individual record.
    • Computed first, as they have no dependencies.
  2. METRIC_OF_BOTH:

    • Depend on both:
      • the current record, and
      • previously computed metrics (specifically, METRIC_OF_RECORDS).
    • Computed after all METRIC_OF_RECORDS have been processed.
    • Must not depend on other METRIC_OF_BOTH or METRIC_OF_METRICS.
  3. METRIC_OF_METRICS:

    • Computed based only on other metrics (not records).
    • May depend on any combination of:
      • METRIC_OF_RECORDS,
      • METRIC_OF_BOTH,
      • other METRIC_OF_METRICS (if dependency order is respected).
    • Computed using a dependency-resolution loop.
This process ensures:
  • All metrics are computed exactly once, after dependencies are satisfied.
  • Misconfigured or cyclic dependencies will raise an explicit runtime error.
Source code in aiperf/services/records_manager/post_processors/metric_summary.py
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
def process(self, records: list[ParsedResponseRecord]) -> None:
    """
    Classifies and computes metrics in dependency order to ensure correctness.
    The metrics are categorized based on their dependency types:

    1. METRIC_OF_RECORDS:
        - Depend solely on each individual record.
        - Computed first, as they have no dependencies.

    2. METRIC_OF_BOTH:
        - Depend on both:
            - the current record, and
            - previously computed metrics (specifically, METRIC_OF_RECORDS).
        - Computed after all METRIC_OF_RECORDS have been processed.
        - Must not depend on other METRIC_OF_BOTH or METRIC_OF_METRICS.

    3. METRIC_OF_METRICS:
        - Computed based only on other metrics (not records).
        - May depend on any combination of:
            - METRIC_OF_RECORDS,
            - METRIC_OF_BOTH,
            - other METRIC_OF_METRICS (if dependency order is respected).
        - Computed using a dependency-resolution loop.

    This process ensures:
        - All metrics are computed exactly once, after dependencies are satisfied.
        - Misconfigured or cyclic dependencies will raise an explicit runtime error.
    """

    # METRIC_OF_RECORDS
    for record in records:
        for metric in self._metrics:
            if metric.type == MetricType.METRIC_OF_RECORDS:
                metric.update_value(record=record)

    # METRIC_OF_BOTH
    for record in records:
        for metric in self._metrics:
            if metric.type == MetricType.METRIC_OF_BOTH:
                metric.update_value(
                    record=record, metrics={m.tag: m for m in self._metrics}
                )

    # METRIC_OF_METRICS
    # Precompute tags of all metrics already processed
    computed_tags = {
        m.tag
        for m in self._metrics
        if m.type in {MetricType.METRIC_OF_RECORDS, MetricType.METRIC_OF_BOTH}
    }

    remaining = [m for m in self._metrics if m.type == MetricType.METRIC_OF_METRICS]

    # Resolve dependencies: loop until all metrics are computed or a circular dependency is found
    while remaining:
        progress = False
        for metric in remaining[:]:
            # If required dependencies are all satisfied, compute this metric
            if metric.required_metrics.issubset(computed_tags):
                metric.update_value(metrics={m.tag: m for m in self._metrics})
                computed_tags.add(metric.tag)
                remaining.remove(metric)
                progress = True

        if not progress:
            # Circular dependencies
            missing = {m.tag: m.required_metrics - computed_tags for m in remaining}
            raise ValueError(
                f"Circular or unsatisfiable dependencies detected in METRIC_OF_METRICS: {missing}"
            )

record_from_dataframe(df, metric)

Create a Record from a DataFrame.

Source code in aiperf/services/records_manager/post_processors/metric_summary.py
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
def record_from_dataframe(df: pd.DataFrame, metric: BaseMetric) -> MetricResult:
    """Create a Record from a DataFrame.

    Summarizes the column of *df* named after ``metric.tag`` into a
    MetricResult carrying mean/min/max, the standard percentile set,
    standard deviation, and the non-null sample count.
    """
    series = df[metric.tag]

    # Compute each requested percentile individually; equivalent to a single
    # vectorized Series.quantile([...]) call with default linear interpolation.
    pct = {
        q: series.quantile(q)
        for q in (0.01, 0.05, 0.25, 0.50, 0.75, 0.90, 0.95, 0.99)
    }

    unit_name = metric.unit.short_name() if metric.unit else ""

    return MetricResult(
        tag=metric.tag,
        header=metric.header,
        unit=unit_name,
        avg=series.mean(),
        min=series.min(),
        max=series.max(),
        p1=pct[0.01],
        p5=pct[0.05],
        p25=pct[0.25],
        p50=pct[0.50],
        p75=pct[0.75],
        p90=pct[0.90],
        p95=pct[0.95],
        p99=pct[0.99],
        std=series.std(),
        count=int(series.count()),
        streaming_only=metric.streaming_only,
    )

aiperf.services.records_manager.post_processors.streaming_post_processor

BaseStreamingPostProcessor

Bases: AIPerfLifecycleMixin, ABC

BaseStreamingPostProcessor is a base class for all classes that wish to stream the incoming ParsedResponseRecords.

Source code in aiperf/services/records_manager/post_processors/streaming_post_processor.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
class BaseStreamingPostProcessor(AIPerfLifecycleMixin, ABC):
    """
    BaseStreamingPostProcessor is a base class for all classes that wish to stream the incoming
    ParsedResponseRecords.

    Records are buffered in a bounded asyncio queue and drained by a
    background task that forwards each one to the subclass's
    ``stream_record`` coroutine.
    """

    def __init__(
        self,
        pub_client: PubClientProtocol,
        sub_client: SubClientProtocol,
        service_id: str,
        service_config: ServiceConfig,
        user_config: UserConfig,
        max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE,
        **kwargs,
    ) -> None:
        # Attributes are assigned before super().__init__ — presumably so the
        # lifecycle mixin can rely on them during its own setup; confirm
        # against AIPerfLifecycleMixin before reordering.
        self.service_id = service_id
        self.user_config = user_config
        self.service_config = service_config
        self.pub_client = pub_client
        self.sub_client = sub_client
        super().__init__(
            pub_client=pub_client,
            sub_client=sub_client,
            user_config=user_config,
            service_config=service_config,
            **kwargs,
        )
        # The lambda defers f-string formatting until the log level is enabled.
        self.info(
            lambda: f"Created streaming post processor: {self.__class__.__name__} with max_queue_size: {max_queue_size:,}"
        )
        # Bounded queue provides back-pressure when the consumer falls behind.
        self.records_queue: asyncio.Queue[ParsedResponseRecord] = asyncio.Queue(
            maxsize=max_queue_size
        )

    @aiperf_task
    async def _stream_records_task(self) -> None:
        """Task that streams records from the queue to the post processor's stream_record method."""
        while True:
            try:
                record = await self.records_queue.get()
                await self.stream_record(record)
                self.records_queue.task_done()
            except asyncio.CancelledError:
                # Normal shutdown path: exit the loop when the task is cancelled.
                break

    @abstractmethod
    async def stream_record(self, record: ParsedResponseRecord) -> None:
        """Handle the incoming record. This method should be implemented by the subclass."""
        raise NotImplementedError(
            "BaseStreamingPostProcessor.stream_record method must be implemented by the subclass."
        )

stream_record(record) abstractmethod async

Handle the incoming record. This method should be implemented by the subclass.

Source code in aiperf/services/records_manager/post_processors/streaming_post_processor.py
62
63
64
65
66
67
@abstractmethod
async def stream_record(self, record: ParsedResponseRecord) -> None:
    """Handle the incoming record. This method should be implemented by the subclass."""
    raise NotImplementedError(
        "BaseStreamingPostProcessor.stream_record method must be implemented by the subclass."
    )

aiperf.services.records_manager.records_manager

DEFAULT_MAX_RECORDS_CONCURRENCY = 100000 module-attribute

The default maximum concurrency for the records manager pull client.

RecordsManager

Bases: BaseComponentService

The RecordsManager service is primarily responsible for holding the results returned from the workers.

Source code in aiperf/services/records_manager/records_manager.py
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
@ServiceFactory.register(ServiceType.RECORDS_MANAGER)
class RecordsManager(BaseComponentService):
    """
    The RecordsManager service is primarily responsible for holding the
    results returned from the workers.

    It pulls parsed inference results pushed by the workers, fans each record
    out to the streaming post processors, tracks per-worker success/error
    counts, and summarizes everything into a ProfileResultsMessage when the
    PROCESS_RECORDS command arrives.
    """

    def __init__(
        self,
        service_config: ServiceConfig,
        user_config: UserConfig,
        service_id: str | None = None,
    ) -> None:
        super().__init__(
            service_config=service_config,
            user_config=user_config,
            service_id=service_id,
        )
        # NOTE(review): this resets user_config to None even though one was
        # just passed to super().__init__ — presumably it is populated later
        # by a configuration step (see configured_event). Confirm, since
        # process_records() skips exporting results while self.user_config
        # is falsy.
        self.user_config: UserConfig | None = None
        self.configured_event = asyncio.Event()

        # TODO: we do not want to keep all the data forever
        self.records: deque[ParsedResponseRecord] = deque()
        self.error_records: deque[ParsedResponseRecord] = deque()

        self.total_expected_requests: int | None = None
        self.error_records_count: int = 0
        self.records_count: int = 0
        self.final_request_count: int | None = None

        # Fix: was_cancelled was previously first assigned inside
        # process_records(), so post_process_records() would raise
        # AttributeError if it ever ran before a PROCESS_RECORDS command.
        self.was_cancelled: bool = False

        # Track per-worker statistics
        self.worker_success_counts: dict[str, int] = {}
        self.worker_error_counts: dict[str, int] = {}

        self.start_time_ns: int = time.time_ns()
        self.end_time_ns: int | None = None

        self.streaming_post_processors: list[BaseStreamingPostProcessor] = []

        # Pull socket bound to the RECORDS address; workers push parsed
        # results to it.
        self.response_results_client: PullClientProtocol = (
            self.comms.create_pull_client(
                CommunicationClientAddressType.RECORDS,
                bind=True,
            )
        )

    @property
    def service_type(self) -> ServiceType:
        """The type of service."""
        return ServiceType.RECORDS_MANAGER

    @on_init
    async def _initialize(self) -> None:
        """Initialize records manager-specific components: command callback,
        pull callback for parsed results, and phase-start subscription."""
        self.debug("Initializing records manager")
        self.register_command_callback(
            CommandType.PROCESS_RECORDS,
            self.process_records,
        )
        await self.response_results_client.register_pull_callback(
            message_type=MessageType.PARSED_INFERENCE_RESULTS,
            callback=self._on_parsed_inference_results,
            max_concurrency=DEFAULT_MAX_RECORDS_CONCURRENCY,
        )

        await self.sub_client.subscribe(
            MessageType.CREDIT_PHASE_START,
            self._on_credit_phase_start,
        )
        # NOTE(review): _on_credit_phase_complete is defined below but no
        # CREDIT_PHASE_COMPLETE subscription is made here, so
        # final_request_count may never be set. Confirm whether that
        # subscription is registered elsewhere or is missing.

    @on_init
    async def _initialize_streaming_post_processors(self) -> None:
        """Initialize the streaming post processors and start their lifecycle."""
        for streamer_type in StreamingPostProcessorFactory.get_all_class_types():
            streamer = StreamingPostProcessorFactory.create_instance(
                class_type=streamer_type,
                pub_client=self.pub_client,
                sub_client=self.sub_client,
                service_id=self.service_id,
                service_config=self.service_config,
                user_config=self.user_config,
            )
            self.debug(f"Initializing streaming post processor: {streamer_type}")
            self.streaming_post_processors.append(streamer)
            self.debug(
                lambda streamer=streamer: f"Starting lifecycle for {streamer.__class__.__name__}"
            )
            await streamer.run_async()

    @on_stop
    async def _stop_streaming_post_processors(self) -> None:
        """Stop the streaming post processors (shutdown requests in parallel)."""
        await asyncio.gather(
            *[streamer.shutdown() for streamer in self.streaming_post_processors]
        )

    @on_cleanup
    async def _cleanup(self) -> None:
        """Cleanup the records manager by waiting for all streamers to shut down."""
        await asyncio.gather(
            *[
                streamer.wait_for_shutdown()
                for streamer in self.streaming_post_processors
            ]
        )

    @aiperf_task
    async def _report_records_task(self) -> None:
        """Periodically publish processing stats until the service stops."""
        while not self.stop_event.is_set():
            await asyncio.sleep(ServiceDefaults.PROGRESS_REPORT_INTERVAL_SECONDS)
            await self.publish_processing_stats()

    async def publish_processing_stats(self) -> None:
        """Publish the overall and per-worker processing statistics."""
        await self.pub_client.publish(
            RecordsProcessingStatsMessage(
                service_id=self.service_id,
                processing_stats=PhaseProcessingStats(
                    processed=self.records_count,
                    errors=self.error_records_count,
                    total_expected_requests=self.total_expected_requests,
                ),
                worker_stats={
                    worker_id: PhaseProcessingStats(
                        processed=self.worker_success_counts[worker_id],
                        errors=self.worker_error_counts[worker_id],
                    )
                    for worker_id in self.worker_success_counts
                },
                request_ns=time.time_ns(),
            ),
        )

    async def _on_credit_phase_start(self, message: CreditPhaseStartMessage) -> None:
        """Record the expected request count when the profiling phase starts."""
        if message.phase == CreditPhase.PROFILING:
            self.total_expected_requests = message.total_expected_requests

    async def _on_credit_phase_complete(
        self, message: CreditPhaseCompleteMessage
    ) -> None:
        """Record the final completed-request count when profiling finishes."""
        if message.phase == CreditPhase.PROFILING:
            self.final_request_count = message.completed

    def _track_error_record(
        self, record: ParsedResponseRecord, worker_id: str
    ) -> None:
        """Store a failed/invalid record and bump the error counters.

        Shared by the has_error and the not-valid paths of
        _on_parsed_inference_results (previously duplicated inline).
        """
        # TODO: we do not want to keep all the data forever
        self.error_records.append(record)
        self.worker_error_counts[worker_id] += 1
        self.error_records_count += 1

    async def _on_parsed_inference_results(
        self, message: ParsedInferenceResultsMessage
    ) -> None:
        """Handle a parsed inference results message: fan out to streamers
        and update success/error bookkeeping."""
        self.trace(lambda: f"Received parsed inference results: {message}")

        # Only profiling-phase records are counted.
        if message.record.request.credit_phase != CreditPhase.PROFILING:
            self.debug(
                lambda: f"Skipping non-profiling record: {message.record.request.credit_phase}"
            )
            return

        # Stream the record to all of the streaming post processors
        for streamer in self.streaming_post_processors:
            try:
                self.debug(
                    lambda name=streamer.__class__.__name__: f"Putting record into queue for streamer {name}"
                )
                streamer.records_queue.put_nowait(message.record)
            except asyncio.QueueFull:
                self.error(
                    f"Streaming post processor {streamer.__class__.__name__} is unable to keep up with the rate of incoming records."
                )
                self.warning(
                    f"Waiting for queue to be available for streamer {streamer.__class__.__name__}. This will cause back pressure on the records manager."
                )
                # Blocking put deliberately applies back pressure upstream.
                await streamer.records_queue.put(message.record)

        worker_id = message.record.worker_id
        # setdefault replaces the previous explicit membership checks.
        self.worker_success_counts.setdefault(worker_id, 0)
        self.worker_error_counts.setdefault(worker_id, 0)

        if message.record.request.has_error:
            self.warning(lambda: f"Received error inference results: {message}")
            self._track_error_record(message.record, worker_id)
        elif message.record.request.valid:
            # TODO: we do not want to keep all the data forever
            self.records.append(message.record)
            self.worker_success_counts[worker_id] += 1
            self.records_count += 1
        else:
            self.warning(lambda: f"Received invalid inference results: {message}")
            self._track_error_record(message.record, worker_id)

        if (
            self.final_request_count is not None
            and self.records_count >= self.final_request_count
        ):
            self.info(
                lambda: f"Processed {self.records_count} requests and {self.error_records_count} errors."
            )
            await self.publish_processing_stats()
            # TODO: Publish PROFILE_RESULTS_COMPLETE message

    async def get_error_summary(self) -> list[ErrorDetailsCount]:
        """Generate a summary of the error records, grouped by error details."""
        summary: dict[ErrorDetails, int] = {}
        for record in self.error_records:
            error = record.request.error
            if error is None:
                continue
            summary[error] = summary.get(error, 0) + 1

        return [
            ErrorDetailsCount(error_details=error_details, count=count)
            for error_details, count in summary.items()
        ]

    async def process_records(self, message: CommandMessage) -> None:
        """Process the records.

        This method is called when the records manager receives a command to
        process the records. It summarizes the collected records, publishes
        the results, and exports them when a user config is available.
        """
        self.notice(lambda: f"Processing records: {message}")
        self.was_cancelled = (
            message.data.cancelled
            if isinstance(message.data, ProcessRecordsCommandData)
            else False
        )
        self.end_time_ns = time.time_ns()
        # TODO: Implement records processing
        self.info(
            lambda: f"Processed {len(self.records)} successful records and {len(self.error_records)} error records"
        )

        profile_results = await self.post_process_records()
        self.info(lambda: f"Profile results: {profile_results}")

        if profile_results:
            await self.pub_client.publish(
                profile_results,
            )

            if self.user_config:
                await ExporterManager(
                    results=profile_results, input_config=self.user_config
                ).export_all()

        else:
            # Fall back to publishing an empty results message so listeners
            # still see a terminal event.
            self.error("No profile results to publish")
            await self.pub_client.publish(
                ProfileResultsMessage(
                    service_id=self.service_id,
                    total=0,
                    completed=0,
                    start_ns=self.start_time_ns,
                    end_ns=self.end_time_ns,
                    records=[],
                    errors_by_type=[],
                    was_cancelled=self.was_cancelled,
                ),
            )

    async def post_process_records(self) -> ProfileResultsMessage | None:
        """Summarize the collected records into a ProfileResultsMessage.

        Returns an empty-records message (with error summary) when no
        successful records were collected.
        """
        self.trace("Post processing records")

        if not self.records:
            self.warning("No successful records to process")
            return ProfileResultsMessage(
                service_id=self.service_id,
                total=len(self.records),
                completed=len(self.records) + len(self.error_records),
                start_ns=self.start_time_ns or time.time_ns(),
                end_ns=self.end_time_ns or time.time_ns(),
                records=[],
                errors_by_type=await self.get_error_summary(),
                was_cancelled=self.was_cancelled,
            )

        self.trace(
            lambda: f"Token counts: {', '.join([str(r.output_token_count) for r in self.records])}"
        )
        metric_summary = MetricSummary()
        metric_summary.process(list(self.records))
        metrics_summary = metric_summary.get_metrics_summary()

        # Create and return ProfileResultsMessage
        return ProfileResultsMessage(
            service_id=self.service_id,
            total=len(self.records),
            completed=len(self.records) + len(self.error_records),
            start_ns=self.start_time_ns or time.time_ns(),
            end_ns=self.end_time_ns or time.time_ns(),
            records=metrics_summary,
            errors_by_type=await self.get_error_summary(),
            was_cancelled=self.was_cancelled,
        )

service_type property

The type of service.

get_error_summary() async

Generate a summary of the error records.

Source code in aiperf/services/records_manager/records_manager.py
253
254
255
256
257
258
259
260
261
262
263
264
265
266
async def get_error_summary(self) -> list[ErrorDetailsCount]:
    """Tally each distinct error across the stored error records."""
    tally: dict[ErrorDetails, int] = {}
    for record in self.error_records:
        error = record.request.error
        # Records without error details are not counted.
        if error is None:
            continue
        tally[error] = tally.get(error, 0) + 1

    return [
        ErrorDetailsCount(error_details=details, count=n)
        for details, n in tally.items()
    ]

post_process_records() async

Post process the records.

Source code in aiperf/services/records_manager/records_manager.py
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
async def post_process_records(self) -> ProfileResultsMessage | None:
    """Summarize the collected records into a ProfileResultsMessage."""
    self.trace("Post processing records")

    def _base_kwargs() -> dict:
        # Fields shared by both the empty and the populated result paths.
        return dict(
            service_id=self.service_id,
            total=len(self.records),
            completed=len(self.records) + len(self.error_records),
            start_ns=self.start_time_ns or time.time_ns(),
            end_ns=self.end_time_ns or time.time_ns(),
            was_cancelled=self.was_cancelled,
        )

    if not self.records:
        self.warning("No successful records to process")
        return ProfileResultsMessage(
            records=[],
            errors_by_type=await self.get_error_summary(),
            **_base_kwargs(),
        )

    self.trace(
        lambda: f"Token counts: {', '.join([str(r.output_token_count) for r in self.records])}"
    )
    summarizer = MetricSummary()
    summarizer.process(list(self.records))

    return ProfileResultsMessage(
        records=summarizer.get_metrics_summary(),
        errors_by_type=await self.get_error_summary(),
        **_base_kwargs(),
    )

process_records(message) async

Process the records.

This method is called when the records manager receives a command to process the records.

Source code in aiperf/services/records_manager/records_manager.py
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
async def process_records(self, message: CommandMessage) -> None:
    """Process the records.

    This method is called when the records manager receives a command to process the records.
    """
    self.notice(lambda: f"Processing records: {message}")
    self.was_cancelled = (
        message.data.cancelled
        if isinstance(message.data, ProcessRecordsCommandData)
        else False
    )
    self.end_time_ns = time.time_ns()
    # TODO: Implement records processing
    self.info(
        lambda: f"Processed {len(self.records)} successful records and {len(self.error_records)} error records"
    )

    profile_results = await self.post_process_records()
    self.info(lambda: f"Profile results: {profile_results}")

    if profile_results:
        await self.pub_client.publish(
            profile_results,
        )

        if self.user_config:
            await ExporterManager(
                results=profile_results, input_config=self.user_config
            ).export_all()

    else:
        self.error("No profile results to publish")
        await self.pub_client.publish(
            ProfileResultsMessage(
                service_id=self.service_id,
                total=0,
                completed=0,
                start_ns=self.start_time_ns,
                end_ns=self.end_time_ns,
                records=[],
                errors_by_type=[],
                was_cancelled=self.was_cancelled,
            ),
        )

publish_processing_stats() async

Publish the profile stats.

Source code in aiperf/services/records_manager/records_manager.py
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
async def publish_processing_stats(self) -> None:
    """Publish the current overall and per-worker processing statistics."""
    per_worker = {
        worker_id: PhaseProcessingStats(
            processed=processed,
            errors=self.worker_error_counts[worker_id],
        )
        for worker_id, processed in self.worker_success_counts.items()
    }
    overall = PhaseProcessingStats(
        processed=self.records_count,
        errors=self.error_records_count,
        total_expected_requests=self.total_expected_requests,
    )
    await self.pub_client.publish(
        RecordsProcessingStatsMessage(
            service_id=self.service_id,
            processing_stats=overall,
            worker_stats=per_worker,
            request_ns=time.time_ns(),
        ),
    )

main()

Main entry point for the records manager.

Source code in aiperf/services/records_manager/records_manager.py
350
351
352
353
354
355
def main() -> None:
    """Bootstrap and run the RecordsManager service."""
    # Imported lazily so merely importing this module stays cheap.
    from aiperf.common.bootstrap import bootstrap_and_run_service

    bootstrap_and_run_service(RecordsManager)

aiperf.services.service_manager.base

BaseServiceManager

Bases: AIPerfLoggerMixin, ABC

Base class for service managers. It provides a common interface for managing services and a way to look up service information by service ID.

Source code in aiperf/services/service_manager/base.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
class BaseServiceManager(AIPerfLoggerMixin, ABC):
    """
    Common interface for service managers.

    Concrete managers launch, stop, and track the required services, and
    expose lookups of service run info by service ID.
    """

    def __init__(
        self,
        required_services: dict[ServiceType, int],
        config: ServiceConfig,
    ):
        super().__init__(logger_name="service_manager")
        self.required_services = required_services
        self.config = config

        # Service type -> every running instance of that type.
        self.service_map: dict[ServiceType, list[ServiceRunInfo]] = {}
        # Service ID -> run info, for per-component lookups.
        self.service_id_map: dict[str, ServiceRunInfo] = {}

    @abstractmethod
    async def run_all_services(self) -> None:
        """Launch every required service."""
        ...

    @abstractmethod
    async def shutdown_all_services(self) -> None:
        """Stop every managed service."""
        ...

    @abstractmethod
    async def kill_all_services(self) -> None:
        """Forcefully terminate every managed service."""
        ...

    @abstractmethod
    async def wait_for_all_services_registration(
        self, stop_event: asyncio.Event, timeout_seconds: int = 30
    ) -> None:
        """Block until every required service has registered, or time out."""
        ...

kill_all_services() abstractmethod async

Kill all managed services.

Source code in aiperf/services/service_manager/base.py
43
44
45
46
@abstractmethod
async def kill_all_services(self) -> None:
    """Forcefully terminate every managed service."""
    ...

run_all_services() abstractmethod async

Run all required services.

Source code in aiperf/services/service_manager/base.py
33
34
35
36
@abstractmethod
async def run_all_services(self) -> None:
    """Launch every required service."""
    ...

shutdown_all_services() abstractmethod async

Shutdown all managed services.

Source code in aiperf/services/service_manager/base.py
38
39
40
41
@abstractmethod
async def shutdown_all_services(self) -> None:
    """Stop every managed service."""
    ...

wait_for_all_services_registration(stop_event, timeout_seconds=30) abstractmethod async

Wait for all required services to be registered.

Source code in aiperf/services/service_manager/base.py
48
49
50
51
52
53
@abstractmethod
async def wait_for_all_services_registration(
    self, stop_event: asyncio.Event, timeout_seconds: int = 30
) -> None:
    """Block until every required service has registered, or time out."""
    ...

aiperf.services.service_manager.kubernetes

KubernetesServiceManager

Bases: BaseServiceManager

Service Manager for starting and stopping services in a Kubernetes cluster.

Source code in aiperf/services/service_manager/kubernetes.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
class KubernetesServiceManager(BaseServiceManager):
    """
    Service Manager for starting and stopping services in a Kubernetes cluster.

    All operations are currently unimplemented placeholders that raise
    NotImplementedError.
    """

    def __init__(
        self,
        required_services: dict[ServiceType, int],
        user_config: UserConfig,
        config: ServiceConfig,
    ):
        super().__init__(required_services, config)
        # Fix: user_config was previously accepted but silently discarded.
        self.user_config = user_config

    async def run_all_services(self) -> None:
        """Initialize all required services as Kubernetes pods."""
        self.logger.debug("Initializing all required services as Kubernetes pods")
        # TODO: Implement Kubernetes
        # Fix: the message previously referenced a non-existent
        # "initialize_all_services" method.
        raise NotImplementedError(
            "KubernetesServiceManager.run_all_services not implemented"
        )

    async def shutdown_all_services(self) -> None:
        """Stop all required services as Kubernetes pods."""
        self.logger.debug("Stopping all required services as Kubernetes pods")
        # TODO: Implement Kubernetes
        # Fix: the message previously referenced a non-existent
        # "stop_all_services" method.
        raise NotImplementedError(
            "KubernetesServiceManager.shutdown_all_services not implemented"
        )

    async def kill_all_services(self) -> None:
        """Kill all required services as Kubernetes pods."""
        self.logger.debug("Killing all required services as Kubernetes pods")
        # TODO: Implement Kubernetes
        raise NotImplementedError(
            "KubernetesServiceManager.kill_all_services not implemented"
        )

    async def wait_for_all_services_registration(
        self, stop_event: asyncio.Event, timeout_seconds: int = 30
    ) -> None:
        """Wait for all required services to be registered in Kubernetes."""
        self.logger.debug(
            "Waiting for all required services to be registered in Kubernetes"
        )
        # TODO: Implement Kubernetes
        raise NotImplementedError(
            "KubernetesServiceManager.wait_for_all_services_registration not implemented"
        )

    async def wait_for_all_services_start(self) -> None:
        """Wait for all required services to be started in Kubernetes."""
        self.logger.debug(
            "Waiting for all required services to be started in Kubernetes"
        )
        # TODO: Implement Kubernetes
        raise NotImplementedError(
            "KubernetesServiceManager.wait_for_all_services_start not implemented"
        )

kill_all_services() async

Kill all required services as Kubernetes pods.

Source code in aiperf/services/service_manager/kubernetes.py
50
51
52
53
54
55
56
async def kill_all_services(self) -> None:
    """Kill all required services as Kubernetes pods."""
    self.logger.debug("Killing all required services as Kubernetes pods")
    # TODO: Implement Kubernetes
    raise NotImplementedError("KubernetesServiceManager.kill_all_services not implemented")

run_all_services() async

Initialize all required services as Kubernetes pods.

Source code in aiperf/services/service_manager/kubernetes.py
34
35
36
37
38
39
40
async def run_all_services(self) -> None:
    """Initialize all required services as Kubernetes pods.

    Raises:
        NotImplementedError: Kubernetes support is not yet implemented.
    """
    self.logger.debug("Initializing all required services as Kubernetes pods")
    # TODO: Implement Kubernetes
    # Fix: the message previously referenced a non-existent
    # "initialize_all_services" method.
    raise NotImplementedError(
        "KubernetesServiceManager.run_all_services not implemented"
    )

shutdown_all_services() async

Stop all required services as Kubernetes pods.

Source code in aiperf/services/service_manager/kubernetes.py
42
43
44
45
46
47
48
async def shutdown_all_services(self) -> None:
    """Stop all required services as Kubernetes pods.

    Raises:
        NotImplementedError: Kubernetes support is not yet implemented.
    """
    self.logger.debug("Stopping all required services as Kubernetes pods")
    # TODO: Implement Kubernetes
    # Fix: the message previously referenced a non-existent
    # "stop_all_services" method.
    raise NotImplementedError(
        "KubernetesServiceManager.shutdown_all_services not implemented"
    )

wait_for_all_services_registration(stop_event, timeout_seconds=30) async

Wait for all required services to be registered in Kubernetes.

Source code in aiperf/services/service_manager/kubernetes.py
58
59
60
61
62
63
64
65
66
67
68
async def wait_for_all_services_registration(
    self, stop_event: asyncio.Event, timeout_seconds: int = 30
) -> None:
    """Wait for all required services to be registered in Kubernetes."""
    # TODO: Implement Kubernetes
    self.logger.debug(
        "Waiting for all required services to be registered in Kubernetes"
    )
    raise NotImplementedError(
        "KubernetesServiceManager.wait_for_all_services_registration not implemented"
    )

wait_for_all_services_start() async

Wait for all required services to be started in Kubernetes.

Source code in aiperf/services/service_manager/kubernetes.py
70
71
72
73
74
75
76
77
78
async def wait_for_all_services_start(self) -> None:
    """Wait for all required services to be started in Kubernetes."""
    # TODO: Implement Kubernetes
    self.logger.debug(
        "Waiting for all required services to be started in Kubernetes"
    )
    raise NotImplementedError(
        "KubernetesServiceManager.wait_for_all_services_start not implemented"
    )

ServiceKubernetesRunInfo

Bases: BaseModel

Information about a service running in a Kubernetes pod.

Source code in aiperf/services/service_manager/kubernetes.py
13
14
15
16
17
18
class ServiceKubernetesRunInfo(BaseModel):
    """Information about a service running in a Kubernetes pod."""

    # Name of the pod the service runs in.
    pod_name: str
    # Name of the cluster node hosting the pod.
    node_name: str
    # Kubernetes namespace the pod lives in.
    namespace: str

aiperf.services.service_manager.multiprocess

MultiProcessRunInfo

Bases: BaseModel

Information about a service running as a multiprocessing process.

Source code in aiperf/services/service_manager/multiprocess.py
22
23
24
25
26
27
28
29
30
31
class MultiProcessRunInfo(BaseModel):
    """Information about a service running as a multiprocessing process."""

    # multiprocessing.Process is not a pydantic-native type, so arbitrary
    # types must be allowed for validation of the `process` field.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Handle to the spawned process, if one has been attached.
    process: Process | SpawnProcess | ForkProcess | None = Field(default=None)
    service_type: ServiceType = Field(
        ...,
        description="Type of service running in the process",
    )

MultiProcessServiceManager

Bases: BaseServiceManager

Service Manager for starting and stopping services as multiprocessing processes.

Source code in aiperf/services/service_manager/multiprocess.py
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
class MultiProcessServiceManager(BaseServiceManager):
    """
    Service Manager for starting and stopping services as multiprocessing processes.
    """

    def __init__(
        self,
        required_services: dict[ServiceType, int],
        config: ServiceConfig,
        user_config: UserConfig | None = None,
        log_queue: "multiprocessing.Queue | None" = None,
    ):
        super().__init__(required_services, config)
        # Run info for every process this manager has started.
        self.multi_process_info: list[MultiProcessRunInfo] = []
        # Queue used by child processes to forward log records, if provided.
        self.log_queue = log_queue
        self.user_config = user_config

    async def _run_services(self, service_types: dict[ServiceType, int]) -> None:
        """Spawn `count` processes for each requested service type."""

        # Create and start all service processes
        for service_type, count in service_types.items():
            service_class = ServiceFactory.get_class_from_type(service_type)

            for _ in range(count):
                process = Process(
                    target=bootstrap_and_run_service,
                    name=f"{service_type}_process",
                    kwargs={
                        "service_class": service_class,
                        # A stable, human-readable ID only makes sense when
                        # there is a single instance of the type.
                        "service_id": service_type.value if count == 1 else None,
                        "service_config": self.config,
                        "user_config": self.user_config,
                        "log_queue": self.log_queue,
                    },
                    daemon=True,
                )
                if service_type in [
                    ServiceType.WORKER_MANAGER,
                ]:
                    process.daemon = False  # Worker manager cannot be a daemon because it needs to be able to spawn worker processes

                process.start()

                self.debug(
                    lambda pid=process.pid,
                    type=service_type: f"Service {type} started as process (pid: {pid})"
                )

                self.multi_process_info.append(
                    MultiProcessRunInfo(process=process, service_type=service_type)
                )

    async def run_all_services(self) -> None:
        """Start all required services as multiprocessing processes."""
        self.logger.debug("Starting all required services as multiprocessing processes")

        try:
            await self._run_services(self.required_services)
        except Exception as e:
            self.logger.error("Error starting services: %s", e)
            # Bare raise preserves the original traceback.
            raise

    async def shutdown_all_services(self) -> None:
        """Stop all required services as multiprocessing processes."""
        self.logger.debug("Stopping all service processes")

        # Wait for all to finish in parallel
        await asyncio.gather(
            *[self._wait_for_process(info) for info in self.multi_process_info]
        )

    async def kill_all_services(self) -> None:
        """Kill all required services as multiprocessing processes."""
        self.logger.debug("Killing all service processes")

        # Kill all processes
        for info in self.multi_process_info:
            if info.process:
                info.process.kill()

        # Wait for all to finish in parallel
        await asyncio.gather(
            *[self._wait_for_process(info) for info in self.multi_process_info]
        )

    def _registered_types(self) -> set[ServiceType]:
        """Return the service types whose registration has completed."""
        return {
            service_info.service_type
            for service_info in self.service_id_map.values()
            if service_info.registration_status
            == ServiceRegistrationStatus.REGISTERED
        }

    async def wait_for_all_services_registration(
        self, stop_event: asyncio.Event, timeout_seconds: int = 30
    ) -> None:
        """Wait for all required services to be registered.

        Args:
            stop_event: Event to check if operation should be cancelled
            timeout_seconds: Maximum time to wait in seconds

        Raises:
            ServiceError: If any service failed to register within the timeout.
        """
        self.logger.debug("Waiting for all required services to register...")

        # Fix: self.required_services is a dict[ServiceType, int], so
        # iterating it yields ServiceType keys directly. The previous code
        # tried to unpack each key as a (type, count) pair
        # (`for typ, _ in required_types`), which raised TypeError.
        required_types_set = set(self.required_services)

        # TODO: Can this be done better by using asyncio.Event()?

        async def _wait_for_registration():
            while not stop_event.is_set():
                # Check if all required types are registered
                if required_types_set.issubset(self._registered_types()):
                    return

                # Wait a bit before checking again
                await asyncio.sleep(0.5)

        try:
            await asyncio.wait_for(_wait_for_registration(), timeout=timeout_seconds)
        except asyncio.TimeoutError as e:
            # Log which services didn't register in time
            registered_types_set = self._registered_types()

            for service_type in required_types_set:
                if service_type not in registered_types_set:
                    self.logger.error(
                        f"Service {service_type} failed to register within timeout"
                    )

            raise ServiceError(
                "Some services failed to register within timeout",
                ServiceType.SYSTEM_CONTROLLER,
                "system_controller",  # TODO: Get the service ID from the system controller
            ) from e

    async def _wait_for_process(self, info: MultiProcessRunInfo) -> None:
        """Terminate a process and wait for it to exit, killing it on timeout."""
        if not info.process or not info.process.is_alive():
            return

        try:
            info.process.terminate()
            await asyncio.wait_for(
                asyncio.to_thread(
                    info.process.join, timeout=TASK_CANCEL_TIMEOUT_SHORT
                ),  # Add timeout to join
                timeout=GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,  # Overall timeout
            )
            self.logger.debug(
                "Service %s process stopped (pid: %d)",
                info.service_type,
                info.process.pid,
            )
        except asyncio.TimeoutError:
            self.logger.warning(
                "Service %s process (pid: %d) did not terminate gracefully, killing",
                info.service_type,
                info.process.pid,
            )
            info.process.kill()

kill_all_services() async

Kill all required services as multiprocessing processes.

Source code in aiperf/services/service_manager/multiprocess.py
106
107
108
109
110
111
112
113
114
115
116
117
118
async def kill_all_services(self) -> None:
    """Kill all required services as multiprocessing processes."""
    self.logger.debug("Killing all service processes")

    # Forcibly kill every process that was actually spawned.
    for run_info in self.multi_process_info:
        if run_info.process:
            run_info.process.kill()

    # Reap all of the killed processes concurrently.
    waiters = [
        self._wait_for_process(run_info) for run_info in self.multi_process_info
    ]
    await asyncio.gather(*waiters)

run_all_services() async

Start all required services as multiprocessing processes.

Source code in aiperf/services/service_manager/multiprocess.py
87
88
89
90
91
92
93
94
95
async def run_all_services(self) -> None:
    """Start all required services as multiprocessing processes.

    Raises:
        Exception: Propagates whatever ``_run_services`` raised, after logging.
    """
    self.logger.debug("Starting all required services as multiprocessing processes")

    try:
        await self._run_services(self.required_services)
    except Exception as e:
        self.logger.error("Error starting services: %s", e)
        # Bare `raise` keeps the original traceback without adding the
        # redundant frame that `raise e` would.
        raise

shutdown_all_services() async

Stop all required services as multiprocessing processes.

Source code in aiperf/services/service_manager/multiprocess.py
 97
 98
 99
100
101
102
103
104
async def shutdown_all_services(self) -> None:
    """Stop all required services as multiprocessing processes."""
    self.logger.debug("Stopping all service processes")

    # Wait for every tracked process to terminate, all in parallel.
    shutdown_waiters = (
        self._wait_for_process(run_info) for run_info in self.multi_process_info
    )
    await asyncio.gather(*shutdown_waiters)

wait_for_all_services_registration(stop_event, timeout_seconds=30) async

Wait for all required services to be registered.

Parameters:

Name Type Description Default
stop_event Event

Event to check if operation should be cancelled

required
timeout_seconds int

Maximum time to wait in seconds

30
Source code in aiperf/services/service_manager/multiprocess.py
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
async def wait_for_all_services_registration(
    self, stop_event: asyncio.Event, timeout_seconds: int = 30
) -> None:
    """Wait for all required services to be registered.

    Polls the service ID map every 0.5s until every required service type
    has completed registration, the stop event is set, or the timeout expires.

    Args:
        stop_event: Event to check if operation should be cancelled
        timeout_seconds: Maximum time to wait in seconds

    Raises:
        ServiceError: If any required service failed to register within
            the timeout.
    """
    self.logger.debug("Waiting for all required services to register...")

    # Get the set of required service types for checking completion
    required_types = set(self.required_services)
    # Each element is a (service_type, count) pair; we only need the types.
    required_types_set = {typ for typ, _ in required_types}

    def _registered_types() -> set:
        """Snapshot the service types that have completed registration."""
        return {
            service_info.service_type
            for service_info in self.service_id_map.values()
            if service_info.registration_status
            == ServiceRegistrationStatus.REGISTERED
        }

    # TODO: Can this be done better by using asyncio.Event()?

    async def _wait_for_registration() -> None:
        # Poll until every required type is registered or we are cancelled.
        while not stop_event.is_set():
            if required_types_set.issubset(_registered_types()):
                return
            # Wait a bit before checking again
            await asyncio.sleep(0.5)

    try:
        await asyncio.wait_for(_wait_for_registration(), timeout=timeout_seconds)
    except asyncio.TimeoutError as e:
        # Log which services didn't register in time
        registered = _registered_types()
        for service_type, _ in required_types:
            if service_type not in registered:
                self.logger.error(
                    f"Service {service_type} failed to register within timeout"
                )

        raise ServiceError(
            "Some services failed to register within timeout",
            ServiceType.SYSTEM_CONTROLLER,
            "system_controller",  # TODO: Get the service ID from the system controller
        ) from e

aiperf.services.system_controller.system_controller

SystemController

Bases: SignalHandlerMixin, BaseControllerService

System Controller service.

This service is responsible for managing the lifecycle of all other services. It will start, stop, and configure all other services.

Source code in aiperf/services/system_controller/system_controller.py
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
@ServiceFactory.register(ServiceType.SYSTEM_CONTROLLER)
class SystemController(SignalHandlerMixin, BaseControllerService):
    """System Controller service.

    This service is responsible for managing the lifecycle of all other services.
    It will start, stop, and configure all other services.
    """

    def __init__(
        self,
        service_config: ServiceConfig,
        user_config: UserConfig,
        service_id: str | None = None,
    ) -> None:
        """Create the controller.

        Args:
            service_config: Service-level configuration (run type, comm config, ...).
            user_config: User benchmark configuration, forwarded to services
                via the PROFILE_CONFIGURE command on registration.
            service_id: Optional explicit service ID; handled by the base class
                when None.
        """
        super().__init__(service_config=service_config, service_id=service_id)
        self.logger.debug("Creating System Controller")

        self._system_state: SystemState = SystemState.INITIALIZING
        self.user_config = user_config

        # List of required service types, in no particular order
        # These are services that must be running before the system controller can start profiling
        # (values are the number of instances to run for each service type)
        self.required_services = {
            ServiceType.DATASET_MANAGER: 1,
            ServiceType.TIMING_MANAGER: 1,
            ServiceType.WORKER_MANAGER: 1,
            ServiceType.RECORDS_MANAGER: 1,
            ServiceType.INFERENCE_RESULT_PARSER: service_config.result_parser_service_count,
        }

        self.service_manager: BaseServiceManager = None  # type: ignore - is set in _initialize

        # ZMQ proxies (and their runner tasks), created in _pre_initialize.
        self.event_bus_proxy: BaseZMQProxy | None = None
        self.event_bus_proxy_task: asyncio.Task | None = None

        self.dataset_manager_proxy: BaseZMQProxy | None = None
        self.dataset_manager_proxy_task: asyncio.Task | None = None

        self.raw_inference_proxy: BaseZMQProxy | None = None
        self.raw_inference_proxy_task: asyncio.Task | None = None

        self.logger.debug("System Controller created")

    @property
    def service_type(self) -> ServiceType:
        """The type of service."""
        return ServiceType.SYSTEM_CONTROLLER

    async def initialize(self) -> None:
        """Override the base initialize method to add pre-initialization and
        post-initialization steps. This allows us to run the UI and progress
        logger before the system is fully initialized.
        """
        await self._pre_initialize()
        # Deliberately calls BaseService.initialize (not super()) — presumably
        # to bypass intermediate overrides in the MRO; TODO confirm.
        await BaseService.initialize(self)
        await self._post_initialize()

    async def _pre_initialize(self) -> None:
        """Initialize system controller-specific components.

        This method will:
        - Initialize the service manager
        - Subscribe to relevant messages
        """
        self.logger.debug("Initializing System Controller")

        self.setup_signal_handlers(self._handle_signal)
        self.logger.debug("Setup signal handlers")

        self.zmq_context = zmq.asyncio.Context.instance()

        # XPUB/XSUB proxy carrying the pub/sub event bus traffic.
        self.event_bus_proxy = ZMQProxyFactory.create_instance(
            ZMQProxyType.XPUB_XSUB,
            context=self.zmq_context,
            zmq_proxy_config=self.service_config.comm_config.event_bus_proxy_config,
        )
        self.event_bus_proxy_task = asyncio.create_task(self.event_bus_proxy.run())

        # DEALER/ROUTER proxy for the dataset manager channel.
        self.dataset_manager_proxy = ZMQProxyFactory.create_instance(
            ZMQProxyType.DEALER_ROUTER,
            context=self.zmq_context,
            zmq_proxy_config=self.service_config.comm_config.dataset_manager_proxy_config,
        )
        self.dataset_manager_proxy_task = asyncio.create_task(
            self.dataset_manager_proxy.run()
        )

        # PUSH/PULL proxy for the raw inference channel.
        self.raw_inference_proxy = ZMQProxyFactory.create_instance(
            ZMQProxyType.PUSH_PULL,
            context=self.zmq_context,
            zmq_proxy_config=self.service_config.comm_config.raw_inference_proxy_config,
        )
        self.raw_inference_proxy_task = asyncio.create_task(
            self.raw_inference_proxy.run()
        )

    async def _post_initialize(self) -> None:
        """Post-initialize the system controller."""

        # Pick the service manager implementation from the configured run type.
        if self.service_config.service_run_type == ServiceRunType.MULTIPROCESSING:
            self.service_manager = MultiProcessServiceManager(
                required_services=self.required_services,
                user_config=self.user_config,
                config=self.service_config,
            )

        elif self.service_config.service_run_type == ServiceRunType.KUBERNETES:
            self.service_manager = KubernetesServiceManager(
                required_services=self.required_services,
                user_config=self.user_config,
                config=self.service_config,
            )

        else:
            raise self._service_error(
                f"Unsupported service run type: {self.service_config.service_run_type}",
            )

        # Subscribe to relevant messages
        subscribe_callbacks = [
            (MessageType.REGISTRATION, self._process_registration_message),
            (MessageType.HEARTBEAT, self._process_heartbeat_message),
            (MessageType.STATUS, self._process_status_message),
            (MessageType.CREDITS_COMPLETE, self._process_credits_complete_message),
            (MessageType.NOTIFICATION, self._process_notification_message),
            (MessageType.COMMAND_RESPONSE, self._process_command_response_message),
        ]
        for message_type, callback in subscribe_callbacks:
            try:
                await self.sub_client.subscribe(
                    message_type=message_type, callback=callback
                )
            except Exception as e:
                self.logger.error(
                    "Failed to subscribe to message_type %s: %s", message_type, e
                )
                raise CommunicationError(
                    f"Failed to subscribe to message_type {message_type}: {e}",
                ) from e

        # TODO: HACK: Wait for 1 second to ensure subscriptions are set up
        await asyncio.sleep(1)

        self._system_state = SystemState.CONFIGURING
        await self._bootstrap_system()

    async def _handle_signal(self, sig: int) -> None:
        """Handle received signals by triggering graceful shutdown.

        Args:
            sig: The signal number received
        """
        self.logger.debug("Received signal %s, initiating graceful shutdown", sig)
        # SIGINT/SIGTERM: just flag the stop event and let the normal
        # shutdown path run.
        if sig == signal.SIGINT or sig == signal.SIGTERM:
            self.stop_event.set()
            return

        if self.pub_client.is_shutdown:
            self.logger.error("Pub client is shutdown, killing all services")
            await self.kill()
            return

        # Ask the records manager to process records (marked cancelled)
        # before stopping.
        await self.send_command_to_service(
            target_service_id=None,
            target_service_type=ServiceType.RECORDS_MANAGER,
            command=CommandType.PROCESS_RECORDS,
            data=ProcessRecordsCommandData(cancelled=True),
        )

        self.stop_event.set()

    async def _bootstrap_system(self) -> None:
        """Bootstrap the system services.

        This method will:
        - Initialize all required services
        - Wait for all required services to be registered
        - Start all required services
        """
        self.debug("System Controller is bootstrapping services")

        # Start all required services
        try:
            await self.service_manager.run_all_services()
        except Exception as e:
            raise self._service_error(
                "Failed to initialize all services",
            ) from e

        # TODO: HACK: Wait for 1 second to ensure registrations made. This needs to be
        # removed once we have the ability to track registrations of services and their state before
        # starting the profiling.
        await asyncio.sleep(1)

        self.info("AIPerf System is READY")
        self._system_state = SystemState.READY

        await self.start_profiling_all_services()

        if self.stop_event.is_set():
            self.debug("System Controller stopped before all services started")
            return  # Don't continue with the rest of the initialization

        self.debug("All required services started successfully")
        self.info("AIPerf System is RUNNING")

    @on_stop
    async def _stop(self) -> None:
        """Stop the system controller and all running services.

        This method will:
        - Stop all running services
        """
        self.debug("Stopping System Controller")
        self.info("AIPerf System is EXITING")
        # logging.root.setLevel(logging.DEBUG)

        self._system_state = SystemState.STOPPING

        # TODO: This is a hack to force printing results again
        # Process records command
        await self.send_command_to_service(
            target_service_id=None,
            target_service_type=ServiceType.RECORDS_MANAGER,
            command=CommandType.PROCESS_RECORDS,
            data=ProcessRecordsCommandData(cancelled=False),
        )

        # Broadcast a stop command to all services
        await self.send_command_to_service(
            target_service_id=None,
            command=CommandType.SHUTDOWN,
        )

        try:
            await self.service_manager.shutdown_all_services()
        except Exception as e:
            raise self._service_error(
                "Failed to stop all services",
            ) from e

        # Stop each proxy, then cancel its runner task; the gather below
        # waits for the cancellations to complete.
        tasks = []
        if self.event_bus_proxy_task:
            await self.event_bus_proxy.stop()
            self.event_bus_proxy_task.cancel()
            tasks.append(self.event_bus_proxy_task)

        if self.dataset_manager_proxy_task:
            await self.dataset_manager_proxy.stop()
            self.dataset_manager_proxy_task.cancel()
            tasks.append(self.dataset_manager_proxy_task)

        if self.raw_inference_proxy_task:
            await self.raw_inference_proxy.stop()
            self.raw_inference_proxy_task.cancel()
            tasks.append(self.raw_inference_proxy_task)

        await asyncio.wait_for(
            asyncio.gather(*tasks),
            timeout=TASK_CANCEL_TIMEOUT_SHORT,
        )

        # TODO: This is a hack to give the services time to produce results
        # await asyncio.sleep(3)

    @on_cleanup
    async def _cleanup(self) -> None:
        """Clean up system controller-specific components."""
        self.debug("Cleaning up System Controller")

        # Hard-kill anything still running, then mark the system shut down.
        await self.kill()

        self._system_state = SystemState.SHUTDOWN

    async def start_profiling_all_services(self) -> None:
        """Tell all services to start profiling."""
        # TODO: HACK: Wait for 1 second to ensure services are ready
        await asyncio.sleep(1)

        self.debug("Sending PROFILE_START command to all services")
        # A None target_service_id broadcasts to all services.
        await self.send_command_to_service(
            target_service_id=None,
            command=CommandType.PROFILE_START,
        )

    async def _process_registration_message(self, message: RegistrationMessage) -> None:
        """Process a registration message from a service. It will
        add the service to the service manager and send a configure command
        to the service.

        Args:
            message: The registration message to process
        """
        service_id = message.service_id
        service_type = message.service_type

        self.logger.info(
            "Processing registration from %s with ID: %s", service_type, service_id
        )

        service_info = ServiceRunInfo(
            registration_status=ServiceRegistrationStatus.REGISTERED,
            service_type=service_type,
            service_id=service_id,
            first_seen=time.time_ns(),
            state=ServiceState.READY,
            last_seen=time.time_ns(),
        )

        # Record the service in both lookup structures: by ID and by type.
        self.service_manager.service_id_map[service_id] = service_info
        if service_type not in self.service_manager.service_map:
            self.service_manager.service_map[service_type] = []
        self.service_manager.service_map[service_type].append(service_info)

        is_required = service_type in self.required_services
        self.logger.info(
            "Registered %s service: %s with ID: %s",
            "required" if is_required else "non-required",
            service_type,
            service_id,
        )

        # Send configure command to the newly registered service
        try:
            await self.send_command_to_service(
                target_service_id=service_id,
                command=CommandType.PROFILE_CONFIGURE,
                data=self.user_config,
            )
        except Exception as e:
            raise self._service_error(
                f"Failed to send configure command to {service_type} (ID: {service_id})",
            ) from e

        self.logger.debug(
            "Sent configure command to %s (ID: %s)", service_type, service_id
        )

    async def _process_heartbeat_message(self, message: HeartbeatMessage) -> None:
        """Process a heartbeat message from a service. It will
        update the last seen timestamp and state of the service.

        Args:
            message: The heartbeat message to process
        """
        service_id = message.service_id
        service_type = message.service_type
        timestamp = message.request_ns

        self.logger.debug(
            "Received heartbeat from %s (ID: %s)", service_type, service_id
        )

        # Update the last heartbeat timestamp if the component exists
        # (EAFP: lookup failure means the service never registered).
        try:
            service_info = self.service_manager.service_id_map[service_id]
            service_info.last_seen = timestamp
            service_info.state = message.state
            self.logger.debug("Updated heartbeat for %s to %s", service_id, timestamp)
        except Exception:
            self.logger.warning(
                f"Received heartbeat from unknown service: {service_id} ({service_type})"
            )

    async def _process_credits_complete_message(
        self, message: CreditsCompleteMessage
    ) -> None:
        """Process a credits complete message from a service. It will
        update the state of the service with the service manager.

        Args:
            message: The credits complete message to process
        """
        # NOTE(review): docstring mentions updating service state, but the
        # body currently only logs.
        service_id = message.service_id
        self.logger.info("Received credits complete from %s", service_id)

    async def _process_status_message(self, message: StatusMessage) -> None:
        """Process a status message from a service. It will
        update the state of the service with the service manager.

        Args:
            message: The status message to process
        """
        service_id = message.service_id
        service_type = message.service_type
        state = message.state

        self.logger.debug(
            f"Received status update from {service_type} (ID: {service_id}): {state}"
        )

        # Update the component state if the component exists
        if service_id not in self.service_manager.service_id_map:
            self.logger.debug(
                "Received status update from un-registered service: %s (%s)",
                service_id,
                service_type,
            )
            return

        service_info = self.service_manager.service_id_map.get(service_id)
        if service_info is None:
            return

        service_info.state = message.state

        self.logger.debug(f"Updated state for {service_id} to {state}")

    async def _process_notification_message(self, message: NotificationMessage) -> None:
        """Process a notification message."""
        self.logger.info("SC: Received notification message: %s", message)

    async def _process_command_response_message(
        self, message: CommandResponseMessage
    ) -> None:
        """Process a command response message."""
        self.logger.debug("SC: Received command response message: %s", message)
        if message.status == CommandResponseStatus.SUCCESS:
            self.logger.debug(
                "SC: Command %s succeeded with data: %s", message.command, message.data
            )
        else:
            self.logger.error(
                "SC: Command %s failed: %s", message.command, message.error
            )
            if message.error:
                self.logger.error("SC: Error details: %s", message.error)

    async def send_command_to_service(
        self,
        target_service_id: str | None,
        command: CommandType,
        data: Any | None = None,
        target_service_type: ServiceType | None = None,
    ) -> None:
        """Send a command to a specific service.

        Args:
            target_service_id: ID of the target service, or None to send to all services
            target_service_type: Type of the target service, or None to send to all services
            command: The command to send (from CommandType enum).
            data: Optional data to send with the command.

        Raises:
            CommunicationError: If the communication is not initialized
                or the command was not sent successfully
        """
        if not self.comms:
            self.logger.error("Cannot send command: Communication is not initialized")
            raise NotInitializedError(
                "Communication channels are not initialized",
            )

        # Publish command message
        try:
            await self.pub_client.publish(
                self.create_command_message(
                    command=command,
                    target_service_id=target_service_id,
                    target_service_type=target_service_type,
                    data=data,
                )
            )
        except Exception as e:
            self.logger.error("Exception publishing command: %s", e)
            raise CommunicationError(f"Failed to publish command: {e}") from e

    async def kill(self) -> None:
        """Kill the system controller."""
        try:
            await self.service_manager.kill_all_services()
        except Exception as e:
            raise self._service_error("Failed to stop all services") from e
service_type property

The type of service.

initialize() async

Override the base initialize method to add pre-initialization and post-initialization steps. This allows us to run the UI and progress logger before the system is fully initialized.

Source code in aiperf/services/system_controller/system_controller.py
 97
 98
 99
100
101
102
103
104
async def initialize(self) -> None:
    """Override the base initialize method to add pre-initialization and
    post-initialization steps. This allows us to run the UI and progress
    logger before the system is fully initialized.
    """
    await self._pre_initialize()
    # Deliberately calls BaseService.initialize (not super()) — presumably
    # to bypass intermediate overrides in the MRO; TODO confirm.
    await BaseService.initialize(self)
    await self._post_initialize()

kill() async

Kill the system controller.

Source code in aiperf/services/system_controller/system_controller.py
516
517
518
519
520
521
async def kill(self) -> None:
    """Kill the system controller.

    Raises:
        The error produced by ``self._service_error`` if killing the managed
        services fails.
    """
    # Force-kill every managed service; wrap any failure in a service-level
    # error so callers see a consistent exception type.
    try:
        await self.service_manager.kill_all_services()
    except Exception as e:
        raise self._service_error("Failed to stop all services") from e

send_command_to_service(target_service_id, command, data=None, target_service_type=None) async

Send a command to a specific service.

Parameters:

Name Type Description Default
target_service_id str | None

ID of the target service, or None to send to all services

required
target_service_type ServiceType | None

Type of the target service, or None to send to all services

None
command CommandType

The command to send (from CommandType enum).

required
data Any | None

Optional data to send with the command.

None

Raises:

Type Description
CommunicationError

If the communication is not initialized or the command was not sent successfully

Source code in aiperf/services/system_controller/system_controller.py
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
async def send_command_to_service(
    self,
    target_service_id: str | None,
    command: CommandType,
    data: Any | None = None,
    target_service_type: ServiceType | None = None,
) -> None:
    """Send a command to a specific service.

    Args:
        target_service_id: ID of the target service, or None to send to all services
        target_service_type: Type of the target service, or None to send to all services
        command: The command to send (from CommandType enum).
        data: Optional data to send with the command.

    Raises:
        CommunicationError: If the communication is not initialized
            or the command was not sent successfully
    """
    # Guard: commands go out over self.pub_client, which is only usable
    # once the communication channels are initialized.
    if not self.comms:
        self.logger.error("Cannot send command: Communication is not initialized")
        raise NotInitializedError(
            "Communication channels are not initialized",
        )

    # Publish command message
    try:
        await self.pub_client.publish(
            self.create_command_message(
                command=command,
                target_service_id=target_service_id,
                target_service_type=target_service_type,
                data=data,
            )
        )
    except Exception as e:
        self.logger.error("Exception publishing command: %s", e)
        raise CommunicationError(f"Failed to publish command: {e}") from e

start_profiling_all_services() async

Tell all services to start profiling.

Source code in aiperf/services/system_controller/system_controller.py
323
324
325
326
327
328
329
330
331
332
async def start_profiling_all_services(self) -> None:
    """Tell all services to start profiling."""
    # TODO: HACK: Wait for 1 second to ensure services are ready
    await asyncio.sleep(1)

    self.debug("Sending PROFILE_START command to all services")
    # A None target_service_id broadcasts the command to all services.
    await self.send_command_to_service(
        target_service_id=None,
        command=CommandType.PROFILE_START,
    )

main()

Main entry point for the system controller.

Source code in aiperf/services/system_controller/system_controller.py
524
525
526
527
528
529
def main() -> None:
    """Main entry point for the system controller."""

    # Function-scope import — presumably to keep module import cheap and/or
    # avoid circular imports; confirm before hoisting to top level.
    from aiperf.common.bootstrap import bootstrap_and_run_service

    bootstrap_and_run_service(SystemController)

aiperf.services.system_controller.system_mixins

SignalHandlerMixin

Mixin for services that need to handle system signals.

Source code in aiperf/services/system_controller/system_mixins.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class SignalHandlerMixin:
    """Mixin for services that need to handle system signals."""

    def __init__(self, *args, **kwargs) -> None:
        # Set to store signal handler tasks to prevent them from being garbage collected
        self._signal_tasks: set[asyncio.Task] = set()
        self.logger = logging.getLogger(__name__)
        super().__init__(*args, **kwargs)

    def setup_signal_handlers(
        self, callback: Callable[[int], Coroutine[Any, Any, None]]
    ) -> None:
        """This method will set up signal handlers for the SIGTERM and SIGINT signals
        in order to trigger a graceful shutdown of the service.

        Args:
            callback: The callback to call when a signal is received
        """
        loop = asyncio.get_running_loop()

        def signal_handler(sig: int) -> None:
            # Create a task and store it so it doesn't get garbage collected
            task = asyncio.create_task(callback(sig))

            # Store the task somewhere to prevent it from being garbage collected
            # before it completes
            self._signal_tasks.add(task)
            task.add_done_callback(self._signal_tasks.discard)

        for sig in (signal.SIGTERM, signal.SIGINT):
            # add_signal_handler forwards extra positional args to the
            # callback, so pass the signal number directly instead of the
            # default-bound lambda workaround for late binding.
            loop.add_signal_handler(sig, signal_handler, sig)

setup_signal_handlers(callback)

This method will set up signal handlers for the SIGTERM and SIGINT signals in order to trigger a graceful shutdown of the service.

Parameters:

Name Type Description Default
callback Callable[[int], Coroutine[Any, Any, None]]

The callback to call when a signal is received

required
Source code in aiperf/services/system_controller/system_mixins.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def setup_signal_handlers(
    self, callback: Callable[[int], Coroutine[Any, Any, None]]
) -> None:
    """This method will set up signal handlers for the SIGTERM and SIGINT signals
    in order to trigger a graceful shutdown of the service.

    Args:
        callback: The callback to call when a signal is received
    """
    loop = asyncio.get_running_loop()

    def signal_handler(sig: int) -> None:
        # Create a task and store it so it doesn't get garbage collected
        task = asyncio.create_task(callback(sig))

        # Store the task somewhere to prevent it from being garbage collected
        # before it completes
        self._signal_tasks.add(task)
        task.add_done_callback(self._signal_tasks.discard)

    for sig in (signal.SIGTERM, signal.SIGINT):
        # add_signal_handler forwards extra positional args to the callback,
        # so pass the signal number directly instead of the default-bound
        # lambda workaround for late binding.
        loop.add_signal_handler(sig, signal_handler, sig)

aiperf.services.timing_manager.concurrency_strategy

ConcurrencyStrategy

Bases: CreditIssuingStrategy, AsyncTaskManagerMixin, AIPerfLoggerMixin

Class for concurrency credit issuing strategy.

Source code in aiperf/services/timing_manager/concurrency_strategy.py
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
@CreditIssuingStrategyFactory.register(TimingMode.CONCURRENCY)
class ConcurrencyStrategy(
    CreditIssuingStrategy, AsyncTaskManagerMixin, AIPerfLoggerMixin
):
    """Class for concurrency credit issuing strategy.

    Caps the number of outstanding credits at ``config.concurrency`` by
    acquiring a semaphore before every credit drop and releasing it on every
    credit return.
    """

    def __init__(
        self, config: TimingManagerConfig, credit_manager: CreditManagerProtocol
    ):
        """Initialize the strategy with its config and credit manager."""
        super().__init__(config=config, credit_manager=credit_manager)

        # If the concurrency is larger than the total number of requests, it does not matter
        # as it is simply an upper bound that will never be reached
        self._semaphore = asyncio.Semaphore(value=config.concurrency)

    async def _execute_single_phase(self, phase_stats: CreditPhaseStats) -> None:
        """Execute a single credit phase. This will not return until the phase sending is complete."""
        # Dispatch based on how the phase is bounded: wall-clock time or request count.
        if phase_stats.is_time_based:
            await self._execute_time_based_phase(phase_stats)
        elif phase_stats.is_request_count_based:
            await self._execute_request_count_based_phase(phase_stats)
        else:
            raise InvalidStateError(
                "Phase must have either a valid total or expected_duration_ns set"
            )

    async def _execute_time_based_phase(self, phase_stats: CreditPhaseStats) -> None:
        """Execute a time-based phase."""

        # Start the internal loop in a task so that we can cancel it when the time expires
        time_task = asyncio.create_task(
            self._execute_time_based_phase_internal(phase_stats)
        )

        # Calculate how long until the phase expires:
        # (phase start converted to seconds) + expected duration - now.
        sleep_time_sec = (
            (phase_stats.start_ns / NANOS_PER_SECOND)  # type: ignore
            + phase_stats.expected_duration_sec
            - time.time()
        )
        self.trace(
            lambda: f"Time-based phase will expire in {sleep_time_sec} seconds: {phase_stats}"
        )

        # Sleep until the phase expires, and then cancel the task
        await asyncio.sleep(sleep_time_sec)
        time_task.cancel()
        self.debug(lambda: f"Time-based phase execution expired: {phase_stats}")
        # Note, not awaiting the task here as we do not want to block moving to the next phase

    async def _execute_time_based_phase_internal(
        self, phase_stats: CreditPhaseStats
    ) -> None:
        """Execute the internal loop for a time-based phase. This will be called within a task and cancelled when the time expires."""

        self.trace(
            lambda: f"_execute_time_based_phase_internal loop entered: {phase_stats}"
        )

        # This will loop until the task is cancelled
        while True:
            try:
                # Acquire the semaphore. Once we hit the concurrency limit, this will block until a credit is returned
                await self._semaphore.acquire()
                self.execute_async(
                    self.credit_manager.drop_credit(
                        credit_phase=phase_stats.type,
                    )
                )
                phase_stats.sent += 1
            except asyncio.CancelledError:
                # Cancellation is the expected exit path when the phase timer fires.
                self.trace(
                    lambda: f"_execute_time_based_phase_internal loop exited: {phase_stats}"
                )
                self.debug("Time-based phase execution expired")
                break

    async def _execute_request_count_based_phase(
        self, phase_stats: CreditPhaseStats
    ) -> None:
        """Issue exactly ``total_expected_requests`` credits, throttled by the semaphore."""
        self.trace(
            lambda: f"_execute_request_count_based_phase loop entered: {phase_stats}"
        )

        total: int = phase_stats.total_expected_requests  # type: ignore

        while phase_stats.sent < total:
            # Blocks once the concurrency limit is reached, until a credit returns.
            await self._semaphore.acquire()
            self.execute_async(
                self.credit_manager.drop_credit(
                    credit_phase=phase_stats.type,
                )
            )
            phase_stats.sent += 1

        self.trace(
            lambda: f"_execute_request_count_based_phase loop exited: {phase_stats}"
        )

    async def _on_credit_return(self, message: CreditReturnMessage) -> None:
        """Process a credit return message."""

        # Release the semaphore to allow another credit to be issued,
        # then call the superclass to handle the credit return like normal
        self._semaphore.release()
        self.trace(lambda: f"Credit return released semaphore: {self._semaphore}")
        await super()._on_credit_return(message)

aiperf.services.timing_manager.config

TimingManagerConfig

Bases: AIPerfBaseModel

Configuration for the timing manager.

Source code in aiperf/services/timing_manager/config.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class TimingManagerConfig(AIPerfBaseModel):
    """Configuration for the timing manager."""

    timing_mode: TimingMode = LoadGeneratorDefaults.TIMING_MODE
    concurrency: int = LoadGeneratorDefaults.CONCURRENCY
    request_rate: float | None = LoadGeneratorDefaults.REQUEST_RATE
    request_rate_mode: RequestRateMode = LoadGeneratorDefaults.REQUEST_RATE_MODE
    request_count: int = LoadGeneratorDefaults.REQUEST_COUNT
    warmup_request_count: int = LoadGeneratorDefaults.WARMUP_REQUEST_COUNT
    random_seed: int | None = None
    progress_report_interval_sec: float = (
        ServiceDefaults.PROGRESS_REPORT_INTERVAL_SECONDS
    )

    @classmethod
    def from_user_config(cls, user_config: UserConfig) -> "TimingManagerConfig":
        """Create a TimingManagerConfig from a UserConfig."""
        loadgen = user_config.loadgen

        # Precedence: a fixed schedule wins, then an explicit request rate;
        # otherwise fall back to concurrency-based timing.
        if user_config.input.fixed_schedule:
            mode = TimingMode.FIXED_SCHEDULE
        else:
            mode = (
                TimingMode.REQUEST_RATE
                if loadgen.request_rate is not None
                else TimingMode.CONCURRENCY
            )

        return cls(
            timing_mode=mode,
            concurrency=loadgen.concurrency,
            request_rate=loadgen.request_rate,
            request_rate_mode=loadgen.request_rate_mode,
            request_count=loadgen.request_count,
            warmup_request_count=loadgen.warmup_request_count,
            random_seed=user_config.input.random_seed,
        )

from_user_config(user_config) classmethod

Create a TimingManagerConfig from a UserConfig.

Source code in aiperf/services/timing_manager/config.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
@classmethod
def from_user_config(cls, user_config: UserConfig) -> "TimingManagerConfig":
    """Create a TimingManagerConfig from a UserConfig."""
    loadgen = user_config.loadgen

    # Precedence: fixed schedule first, then explicit request rate,
    # finally the concurrency default.
    if user_config.input.fixed_schedule:
        mode = TimingMode.FIXED_SCHEDULE
    elif loadgen.request_rate is not None:
        mode = TimingMode.REQUEST_RATE
    else:
        mode = TimingMode.CONCURRENCY

    return cls(
        timing_mode=mode,
        concurrency=loadgen.concurrency,
        request_rate=loadgen.request_rate,
        request_rate_mode=loadgen.request_rate_mode,
        request_count=loadgen.request_count,
        warmup_request_count=loadgen.warmup_request_count,
        random_seed=user_config.input.random_seed,
    )

aiperf.services.timing_manager.credit_issuing_strategy

CreditIssuingStrategy

Bases: AsyncTaskManagerMixin, AIPerfLoggerMixin, ABC

Base class for credit issuing strategies.

Source code in aiperf/services/timing_manager/credit_issuing_strategy.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
class CreditIssuingStrategy(AsyncTaskManagerMixin, AIPerfLoggerMixin, ABC):
    """
    Base class for credit issuing strategies.

    Owns the phase configuration, lifecycle messaging, progress reporting, and
    per-phase statistics. Subclasses implement ``_execute_single_phase`` to
    define how credits are actually issued for a phase.
    """

    def __init__(
        self, config: TimingManagerConfig, credit_manager: CreditManagerProtocol
    ):
        """Initialize the strategy.

        Args:
            config: Timing manager configuration for this run.
            credit_manager: Manager used to drop credits and publish phase messages.
        """
        super().__init__()
        self.config = config
        self.credit_manager = credit_manager

        # This event is set when all phases are complete
        self.all_phases_complete_event = asyncio.Event()

        # The running stats for each phase, keyed by phase type.
        self.phase_stats: dict[CreditPhase, CreditPhaseStats] = {}

        # The phases to run including their configuration, in order of execution.
        self.ordered_phase_configs: list[CreditPhaseConfig] = []

        self._setup_phase_configs()
        self._validate_phase_configs()

    def _setup_phase_configs(self) -> None:
        """Setup the phases for the strategy. This can be overridden in subclasses to modify the phases."""
        self._setup_warmup_phase_config()
        self._setup_profiling_phase_config()
        self.info(
            lambda: f"Credit issuing strategy {self.__class__.__name__} initialized with {len(self.ordered_phase_configs)} "
            f"phase(s): {self.ordered_phase_configs}"
        )

    def _setup_warmup_phase_config(self) -> None:
        """Setup the warmup phase. This can be overridden in subclasses to modify the warmup phase."""
        # Warmup is optional; skip it when no warmup requests are configured.
        if self.config.warmup_request_count > 0:
            self.ordered_phase_configs.append(
                CreditPhaseConfig(
                    type=CreditPhase.WARMUP,
                    total_expected_requests=self.config.warmup_request_count,
                )
            )

    def _setup_profiling_phase_config(self) -> None:
        """Setup the profiling phase. This can be overridden in subclasses to modify the profiling phase."""
        self.ordered_phase_configs.append(
            CreditPhaseConfig(
                type=CreditPhase.PROFILING,
                total_expected_requests=self.config.request_count,
            )
        )

    def _validate_phase_configs(self) -> None:
        """Validate the phase configs.

        Raises:
            ConfigurationError: If a phase has neither a request total nor a duration.
        """
        for phase_config in self.ordered_phase_configs:
            if not phase_config.is_valid:
                raise ConfigurationError(
                    f"Phase {phase_config.type} is not valid. It must have either a valid total_expected_requests or expected_duration_sec set"
                )

    async def start(self) -> None:
        """Start the credit issuing strategy. This will launch the progress reporting loop, the
        warmup phase (if applicable), and the profiling phase, all in the background."""
        self.debug(
            lambda: f"Starting credit issuing strategy {self.__class__.__name__}"
        )
        self.all_phases_complete_event.clear()

        # Start the progress reporting loop in the background
        self.execute_async(self._progress_report_loop())

        # Execute the phases in the background
        self.execute_async(self._execute_phases())

        self.debug(
            lambda: f"Waiting for all credit phases to complete for {self.__class__.__name__}"
        )
        # Wait for all phases to complete before returning
        await self.all_phases_complete_event.wait()
        self.debug(lambda: f"All credit phases completed for {self.__class__.__name__}")

    async def _execute_phases(self) -> None:
        """Execute all of the credit phases sequentially. This can be overridden in subclasses to modify the execution of the phases."""
        for phase_config in self.ordered_phase_configs:
            phase_stats = CreditPhaseStats.from_phase_config(phase_config)
            phase_stats.start_ns = time.time_ns()
            self.phase_stats[phase_config.type] = phase_stats

            self.execute_async(
                self.credit_manager.publish_phase_start(
                    phase_config.type,
                    phase_stats.start_ns,
                    # Only one of the below will be set, this is already validated in the strategy
                    phase_config.total_expected_requests,
                    phase_config.expected_duration_sec,
                )
            )

            # This is implemented in subclasses
            await self._execute_single_phase(phase_stats)

            # We have sent all the credits for this phase. We must continue to the next
            # phase even though not all the credits have been returned. This is because
            # we do not want a gap in the credit issuing.
            phase_stats.sent_end_ns = time.time_ns()
            self.execute_async(
                self.credit_manager.publish_phase_sending_complete(
                    phase_config.type, phase_stats.sent_end_ns
                )
            )

    @abstractmethod
    async def _execute_single_phase(self, phase_stats: CreditPhaseStats) -> None:
        """Execute a single phase. Should not return until the phase sending is complete. Must be implemented in subclasses."""
        raise NotImplementedError("Subclasses must implement this method")

    async def stop(self) -> None:
        """Stop the credit issuing strategy."""
        # Cancel any in-flight background tasks (progress loop, phase execution).
        await self.cancel_all_tasks()

    async def _on_credit_return(self, message: CreditReturnMessage) -> None:
        """This is called by the credit manager when a credit is returned. It can be
        overridden in subclasses to handle the credit return."""
        if message.phase not in self.phase_stats:
            # Stats for a completed phase are removed below, so a late credit
            # return for that phase is silently ignored.
            return

        phase_stats = self.phase_stats[message.phase]
        phase_stats.completed += 1

        if (
            # If we have sent all the credits, check if this is the last one to be returned
            phase_stats.is_sending_complete
            and phase_stats.completed >= phase_stats.total_expected_requests  # type: ignore[operator]
        ):
            phase_stats.end_ns = time.time_ns()
            self.info(lambda: f"Phase completed: {phase_stats}")

            self.execute_async(
                self.credit_manager.publish_phase_complete(
                    message.phase, phase_stats.completed, phase_stats.end_ns
                )
            )

            if phase_stats.type == CreditPhase.PROFILING:
                # Profiling is the final phase, so signal overall completion.
                self.execute_async(self.credit_manager.publish_credits_complete())
                self.all_phases_complete_event.set()

            # We don't need to keep track of the phase stats anymore
            self.notice(
                lambda: f"Phase {message.phase} completed, removing phase stats"
            )
            self.phase_stats.pop(message.phase)

    async def _progress_report_loop(self) -> None:
        """Report the progress at a fixed interval."""
        self.debug("Starting progress reporting loop")
        while not self.all_phases_complete_event.is_set():
            await asyncio.sleep(self.config.progress_report_interval_sec)

            # Iterate over a snapshot: _on_credit_return may pop a finished
            # phase while publish_progress is awaited below, and mutating the
            # dict during iteration would raise a RuntimeError.
            for phase, stats in list(self.phase_stats.items()):
                try:
                    await self.credit_manager.publish_progress(
                        phase, stats.sent, stats.completed
                    )
                except asyncio.CancelledError:
                    # Handled before Exception so cancellation is never treated
                    # as a publishing error (and to stay correct on runtimes
                    # where CancelledError subclasses Exception).
                    self.debug("Credit progress reporting loop cancelled")
                    return
                except Exception as e:
                    self.error(f"Error publishing credit progress: {e}")

        self.debug("All credits completed, stopping credit progress reporting loop")

start() async

Start the credit issuing strategy. This will launch the progress reporting loop, the warmup phase (if applicable), and the profiling phase, all in the background.

Source code in aiperf/services/timing_manager/credit_issuing_strategy.py
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
async def start(self) -> None:
    """Start the credit issuing strategy. This will launch the progress reporting loop, the
    warmup phase (if applicable), and the profiling phase, all in the background.

    Returns only after every phase has completed (all_phases_complete_event is set).
    """
    self.debug(
        lambda: f"Starting credit issuing strategy {self.__class__.__name__}"
    )
    # Reset the completion flag before launching any background work.
    self.all_phases_complete_event.clear()

    # Start the progress reporting loop in the background
    self.execute_async(self._progress_report_loop())

    # Execute the phases in the background
    self.execute_async(self._execute_phases())

    self.debug(
        lambda: f"Waiting for all credit phases to complete for {self.__class__.__name__}"
    )
    # Wait for all phases to complete before returning
    await self.all_phases_complete_event.wait()
    self.debug(lambda: f"All credit phases completed for {self.__class__.__name__}")

stop() async

Stop the credit issuing strategy.

Source code in aiperf/services/timing_manager/credit_issuing_strategy.py
134
135
136
async def stop(self) -> None:
    """Stop the credit issuing strategy.

    Cancels all in-flight background tasks (progress loop, phase execution).
    """
    await self.cancel_all_tasks()

CreditIssuingStrategyFactory

Bases: FactoryMixin[TimingMode, CreditIssuingStrategy]

Factory for creating credit issuing strategies based on the timing mode.

Source code in aiperf/services/timing_manager/credit_issuing_strategy.py
194
195
# Concrete strategies self-register via
# @CreditIssuingStrategyFactory.register(TimingMode.<MODE>).
class CreditIssuingStrategyFactory(FactoryMixin[TimingMode, CreditIssuingStrategy]):
    """Factory for creating credit issuing strategies based on the timing mode."""

aiperf.services.timing_manager.credit_manager

CreditManagerProtocol

Bases: Protocol

Defines the interface for a CreditManager.

This is used to allow the credit issuing strategy to interact with the TimingManager in a decoupled way.

Source code in aiperf/services/timing_manager/credit_manager.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
@runtime_checkable
class CreditManagerProtocol(Protocol):
    """Defines the interface for a CreditManager.

    This is used to allow the credit issuing strategy to interact with the TimingManager
    in a decoupled way.
    """

    # Issue a single credit for the given phase; optionally pinned to a
    # conversation and/or a scheduled drop time (None means "send ASAP").
    async def drop_credit(
        self,
        credit_phase: CreditPhase,
        conversation_id: str | None = None,
        credit_drop_ns: int | None = None,
    ) -> None: ...

    # Publish a periodic progress update (credits sent vs. completed).
    async def publish_progress(
        self, phase: CreditPhase, sent: int, completed: int
    ) -> None: ...

    # Publish that all credit issuing has finished.
    async def publish_credits_complete(self) -> None: ...

    # Publish the start of a phase; only one of total_expected_requests or
    # expected_duration_sec will be set (validated by the strategy).
    async def publish_phase_start(
        self,
        phase: CreditPhase,
        start_ns: int,
        total_expected_requests: int | None,
        expected_duration_sec: float | None,
    ) -> None: ...

    # Publish that all credits for the phase have been sent (not yet returned).
    async def publish_phase_sending_complete(
        self, phase: CreditPhase, sent_end_ns: int
    ) -> None: ...

    # Publish that every credit for the phase has been returned.
    async def publish_phase_complete(
        self, phase: CreditPhase, completed: int, end_ns: int
    ) -> None: ...

CreditPhaseMessagesMixin

Bases: AsyncTaskManagerMixin, CreditPhaseMessagesRequirements

Mixin for services to implement the CreditManagerProtocol.

Requirements

This mixin must be used with a class that provides:

- pub_client: PubClientProtocol
- service_id: str

Source code in aiperf/services/timing_manager/credit_manager.py
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
class CreditPhaseMessagesMixin(AsyncTaskManagerMixin, CreditPhaseMessagesRequirements):
    """Mixin for services to implement the CreditManagerProtocol.

    Requirements:
        This mixin must be used with a class that provides:
        - pub_client: PubClientProtocol
        - service_id: str
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Fail fast if the host class does not satisfy the structural protocol.
        if not isinstance(self, CreditPhaseMessagesRequirements):
            raise TypeError(
                "CreditPhaseMessagesMixin must be used with a class that provides CreditPhaseMessagesRequirements"
            )

    async def publish_phase_start(
        self,
        phase: CreditPhase,
        start_ns: int,
        total_expected_requests: int | None,
        expected_duration_sec: float | None,
    ) -> None:
        """Publish the phase start message."""
        message = CreditPhaseStartMessage(
            service_id=self.service_id,
            phase=phase,
            start_ns=start_ns,
            # Only one of the below will be set, this is already validated in the strategy
            total_expected_requests=total_expected_requests,
            expected_duration_sec=expected_duration_sec,
        )
        # Publish in a managed background task so callers are not blocked.
        self.execute_async(self.pub_client.publish(message))

    async def publish_phase_sending_complete(
        self, phase: CreditPhase, sent_end_ns: int
    ) -> None:
        """Publish the phase sending complete message."""
        message = CreditPhaseSendingCompleteMessage(
            service_id=self.service_id,
            phase=phase,
            sent_end_ns=sent_end_ns,
        )
        # Publish in a managed background task so callers are not blocked.
        self.execute_async(self.pub_client.publish(message))

    async def publish_phase_complete(
        self, phase: CreditPhase, completed: int, end_ns: int
    ) -> None:
        """Publish the phase complete message."""
        message = CreditPhaseCompleteMessage(
            service_id=self.service_id,
            phase=phase,
            completed=completed,
            end_ns=end_ns,
        )
        # Publish in a managed background task so callers are not blocked.
        self.execute_async(self.pub_client.publish(message))

    async def publish_progress(
        self, phase: CreditPhase, sent: int, completed: int
    ) -> None:
        """Publish the progress message."""
        message = CreditPhaseProgressMessage(
            service_id=self.service_id,
            phase=phase,
            sent=sent,
            completed=completed,
        )
        # Publish in a managed background task so callers are not blocked.
        self.execute_async(self.pub_client.publish(message))

    async def publish_credits_complete(self) -> None:
        """Publish the credits complete message."""
        self.debug("Publishing credits complete message")
        message = CreditsCompleteMessage(service_id=self.service_id)
        # Publish in a managed background task so callers are not blocked.
        self.execute_async(self.pub_client.publish(message))

publish_credits_complete() async

Publish the credits complete message.

Source code in aiperf/services/timing_manager/credit_manager.py
152
153
154
155
156
157
async def publish_credits_complete(self) -> None:
    """Publish the credits complete message."""
    self.debug("Publishing credits complete message")
    message = CreditsCompleteMessage(service_id=self.service_id)
    # Publish in a managed background task so the caller is not blocked.
    self.execute_async(self.pub_client.publish(message))

publish_phase_complete(phase, completed, end_ns) async

Publish the phase complete message.

Source code in aiperf/services/timing_manager/credit_manager.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
async def publish_phase_complete(
    self, phase: CreditPhase, completed: int, end_ns: int
) -> None:
    """Publish the phase complete message."""
    message = CreditPhaseCompleteMessage(
        service_id=self.service_id,
        phase=phase,
        completed=completed,
        end_ns=end_ns,
    )
    # Publish in a managed background task so the caller is not blocked.
    self.execute_async(self.pub_client.publish(message))

publish_phase_sending_complete(phase, sent_end_ns) async

Publish the phase sending complete message.

Source code in aiperf/services/timing_manager/credit_manager.py
108
109
110
111
112
113
114
115
116
117
118
119
120
async def publish_phase_sending_complete(
    self, phase: CreditPhase, sent_end_ns: int
) -> None:
    """Publish the phase sending complete message."""
    message = CreditPhaseSendingCompleteMessage(
        service_id=self.service_id,
        phase=phase,
        sent_end_ns=sent_end_ns,
    )
    # Publish in a managed background task so the caller is not blocked.
    self.execute_async(self.pub_client.publish(message))

publish_phase_start(phase, start_ns, total_expected_requests, expected_duration_sec) async

Publish the phase start message.

Source code in aiperf/services/timing_manager/credit_manager.py
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
async def publish_phase_start(
    self,
    phase: CreditPhase,
    start_ns: int,
    total_expected_requests: int | None,
    expected_duration_sec: float | None,
) -> None:
    """Publish the phase start message."""
    message = CreditPhaseStartMessage(
        service_id=self.service_id,
        phase=phase,
        start_ns=start_ns,
        # Only one of the below will be set, this is already validated in the strategy
        total_expected_requests=total_expected_requests,
        expected_duration_sec=expected_duration_sec,
    )
    # Publish in a managed background task so the caller is not blocked.
    self.execute_async(self.pub_client.publish(message))

publish_progress(phase, sent, completed) async

Publish the progress message.

Source code in aiperf/services/timing_manager/credit_manager.py
137
138
139
140
141
142
143
144
145
146
147
148
149
150
async def publish_progress(
    self, phase: CreditPhase, sent: int, completed: int
) -> None:
    """Publish the progress message."""
    message = CreditPhaseProgressMessage(
        service_id=self.service_id,
        phase=phase,
        sent=sent,
        completed=completed,
    )
    # Publish in a managed background task so the caller is not blocked.
    self.execute_async(self.pub_client.publish(message))

CreditPhaseMessagesRequirements

Bases: AsyncTaskManagerProtocol, AIPerfLoggerProtocol, Protocol

Requirements for the CreditPhaseMessagesMixin. This is the list of attributes that must be provided by the class that uses this mixin.

Source code in aiperf/services/timing_manager/credit_manager.py
60
61
62
63
64
65
66
67
68
@runtime_checkable
class CreditPhaseMessagesRequirements(
    AsyncTaskManagerProtocol, AIPerfLoggerProtocol, Protocol
):
    """Requirements for the CreditPhaseMessagesMixin. This is the list of attributes that must
    be provided by the class that uses this mixin."""

    # Publish client used to emit the credit phase messages.
    pub_client: PubClientProtocol
    # Identifier of the hosting service; stamped on every published message.
    service_id: str

aiperf.services.timing_manager.fixed_schedule_strategy

FixedScheduleStrategy

Bases: CreditIssuingStrategy, AsyncTaskManagerMixin

Class for fixed schedule credit issuing strategy.

Source code in aiperf/services/timing_manager/fixed_schedule_strategy.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
@CreditIssuingStrategyFactory.register(TimingMode.FIXED_SCHEDULE)
class FixedScheduleStrategy(CreditIssuingStrategy, AsyncTaskManagerMixin):
    """
    Class for fixed schedule credit issuing strategy.

    Replays a recorded schedule of (timestamp, conversation_id) pairs,
    dispatching each credit at its recorded offset relative to phase start.
    """

    def __init__(
        self,
        config: TimingManagerConfig,
        credit_manager: CreditManagerProtocol,
        schedule: list[tuple[int, str]],
    ):
        """Initialize the strategy.

        Args:
            config: Timing manager configuration for this run.
            credit_manager: Manager used to drop credits and publish phase messages.
            schedule: Ordered (timestamp, conversation_id) pairs to replay.
        """
        super().__init__(config=config, credit_manager=credit_manager)

        self._schedule: list[tuple[int, str]] = schedule

    async def _execute_single_phase(self, phase_stats: CreditPhaseStats) -> None:
        """Drop credits according to the fixed schedule. Returns once every
        scheduled credit has been dispatched."""
        if not self._schedule:
            self.warning("No schedule loaded, no credits will be dropped")
            return

        start_time_ns = time.time_ns()

        # Group conversation ids by timestamp so entries sharing a timestamp
        # are dispatched together after a single wait.
        timestamp_groups = defaultdict(list)

        for timestamp, conversation_id in self._schedule:
            timestamp_groups[timestamp].append((timestamp, conversation_id))

        schedule_unique_sorted = sorted(timestamp_groups.keys())

        for unique_timestamp in schedule_unique_sorted:
            wait_duration_ns = max(0, start_time_ns + unique_timestamp - time.time_ns())
            wait_duration_sec = wait_duration_ns / 1_000_000_000

            if wait_duration_sec > 0:
                await asyncio.sleep(wait_duration_sec)

            for _, conversation_id in timestamp_groups[unique_timestamp]:
                self.execute_async(
                    self.credit_manager.drop_credit(
                        # Use the phase currently executing instead of
                        # hard-coding PROFILING, so warmup credits (if any)
                        # are attributed to the correct phase.
                        credit_phase=phase_stats.type,
                        conversation_id=conversation_id,
                        # We already waited, so it can be sent ASAP
                        credit_drop_ns=None,
                    )
                )
                # Track sent credits so the base-class progress reporting and
                # phase-completion bookkeeping see accurate counts, matching
                # the other credit issuing strategies.
                phase_stats.sent += 1

        self.info("Completed all scheduled credit drops")

aiperf.services.timing_manager.request_rate_strategy

RequestRateStrategy

Bases: CreditIssuingStrategy, AsyncTaskManagerMixin

Strategy for issuing credits based on a specified request rate.

Supports two modes:

- CONSTANT: Issues credits at a constant rate with fixed intervals
- POISSON: Issues credits using a Poisson process with exponentially distributed intervals

Source code in aiperf/services/timing_manager/request_rate_strategy.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
@CreditIssuingStrategyFactory.register(TimingMode.REQUEST_RATE)
class RequestRateStrategy(CreditIssuingStrategy, AsyncTaskManagerMixin):
    """
    Strategy for issuing credits based on a specified request rate.

    Supports two modes:
    - CONSTANT: Issues credits at a constant rate with fixed intervals
    - POISSON: Issues credits using a Poisson process with exponentially distributed intervals
    """

    def __init__(
        self, config: TimingManagerConfig, credit_manager: CreditManagerProtocol
    ):
        """Validate the timing config and set up the rate parameters and RNG.

        Args:
            config: Timing manager configuration (rate, mode, count, seed).
            credit_manager: Manager used to actually drop credits.

        Raises:
            InvalidStateError: If the request rate is unset or the request
                count is less than 1.
        """
        super().__init__(config=config, credit_manager=credit_manager)

        if config.request_rate is None:
            raise InvalidStateError("Request rate is not set")
        if config.request_count < 1:
            raise InvalidStateError("Request count must be at least 1")

        self._request_rate = config.request_rate
        self._request_rate_mode = config.request_rate_mode

        # Initialize random number generator for reproducibility.
        # BUGFIX: compare against None explicitly -- a truthiness check would
        # silently discard a seed of 0, which is a perfectly valid seed, and
        # make "reproducible" runs non-reproducible.
        self._random = (
            random.Random(config.random_seed)
            if config.random_seed is not None
            else random.Random()
        )

    async def _execute_single_phase(self, phase_stats: CreditPhaseStats) -> None:
        """Execute a single phase. This will not return until the phase sending is complete.

        Raises:
            InvalidStateError: If the configured request rate mode is unsupported.
        """
        # Issue credit drops at the specified rate
        if self._request_rate_mode == RequestRateMode.CONSTANT:
            await self._execute_constant_rate(phase_stats)
        elif self._request_rate_mode == RequestRateMode.POISSON:
            await self._execute_poisson_rate(phase_stats)
        else:
            raise InvalidStateError(
                f"Unsupported request rate mode: {self._request_rate_mode}"
            )

    async def _execute_constant_rate(self, phase_stats: CreditPhaseStats) -> None:
        """Execute credit drops at a constant rate until the phase says to stop."""

        # The effective time between each credit drop is the inverse of the request rate.
        period_sec = 1.0 / self._request_rate

        # We start by sending the first credit immediately.
        next_drop_at = time.perf_counter()

        while phase_stats.should_send:
            wait_sec = next_drop_at - time.perf_counter()
            if wait_sec > 0:
                await asyncio.sleep(wait_sec)

            # Fire-and-forget: the drop executes as a background task so the
            # scheduling loop itself never falls behind on slow drops.
            self.execute_async(
                self.credit_manager.drop_credit(credit_phase=phase_stats.type)
            )
            phase_stats.sent += 1

            # Instead of naively sleeping for a constant period_sec, we are scheduling the
            # next drop to happen at exactly (next_drop_at + period_sec). This ensures that
            # we do not slowly drift over time based on slight variances in the asyncio.sleep
            # or executing the credit drop task.
            next_drop_at += period_sec

    async def _execute_poisson_rate(self, phase_stats: CreditPhaseStats) -> None:
        """Execute credit drops using Poisson process (exponential inter-arrival times).

        In a Poisson process with rate λ (requests per second), the inter-arrival times
        are exponentially distributed with parameter λ. This models realistic traffic
        patterns where requests arrive randomly but at a consistent average rate.
        """
        while phase_stats.should_send:
            # For Poisson process, inter-arrival times are exponentially distributed.
            # random.expovariate(lambd) generates exponentially distributed random numbers
            # where lambd is the rate parameter (requests per second)
            wait_duration_sec = self._random.expovariate(self._request_rate)

            if wait_duration_sec > 0:
                await asyncio.sleep(wait_duration_sec)

            self.execute_async(
                self.credit_manager.drop_credit(credit_phase=phase_stats.type)
            )
            phase_stats.sent += 1

aiperf.services.timing_manager.timing_manager

TimingManager

Bases: BaseComponentService, CreditPhaseMessagesMixin

The TimingManager service is responsible for generating the schedule and issuing timing credits for requests.

Source code in aiperf/services/timing_manager/timing_manager.py
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
@ServiceFactory.register(ServiceType.TIMING_MANAGER)
class TimingManager(BaseComponentService, CreditPhaseMessagesMixin):
    """
    The TimingManager service is responsible for generating the schedule and
    issuing timing credits for requests.
    """

    def __init__(
        self,
        service_config: ServiceConfig,
        user_config: UserConfig | None,
        service_id: str | None = None,
    ) -> None:
        super().__init__(
            service_config=service_config,
            user_config=user_config,
            service_id=service_id,
        )
        self.debug("Initializing timing manager")

        # Request/response client used to fetch the fixed schedule from the DatasetManager.
        self.dataset_request_client: RequestClientProtocol = (
            self.comms.create_request_client(
                CommunicationClientAddressType.DATASET_MANAGER_PROXY_FRONTEND,
            )
        )
        # Push client that workers pull credit drops from (we bind, workers connect).
        self.credit_drop_push_client: PushClientProtocol = (
            self.comms.create_push_client(
                CommunicationClientAddressType.CREDIT_DROP,
                bind=True,
            )
        )
        # Pull client that receives credits returned by workers after processing.
        self.credit_return_pull_client: PullClientProtocol = (
            self.comms.create_pull_client(
                CommunicationClientAddressType.CREDIT_RETURN,
                bind=True,
            )
        )

        # Chosen during configure based on the user's timing mode.
        self._credit_issuing_strategy: CreditIssuingStrategy | None = None

    @property
    def service_type(self) -> ServiceType:
        """The type of service."""
        return ServiceType.TIMING_MANAGER

    @on_init
    async def _timing_manager_initialize(self) -> None:
        """Initialize timing manager-specific components."""
        self.debug("Initializing timing manager")
        self.config = TimingManagerConfig.from_user_config(self.user_config)
        await self.credit_return_pull_client.register_pull_callback(
            message_type=MessageType.CREDIT_RETURN,
            callback=self._on_credit_return,
        )

    @on_configure
    async def _timing_manager_configure(self, message: CommandMessage) -> None:
        """Configure the timing manager by creating the credit issuing strategy.

        Raises:
            InvalidStateError: If the configured timing mode has no strategy.
        """
        self.debug(lambda: f"Configuring timing manager with message: {message}")

        # Per-mode extra kwargs for the strategy factory; only FIXED_SCHEDULE
        # needs anything beyond the config and the credit manager.
        extra_kwargs = {}
        if self.config.timing_mode == TimingMode.FIXED_SCHEDULE:
            # This will block until the dataset is ready and the timing response is received
            dataset_timing_response: DatasetTimingResponse = (
                await self.dataset_request_client.request(
                    message=DatasetTimingRequest(
                        service_id=self.service_id,
                    ),
                )
            )
            self.debug(
                lambda: f"TM: Received dataset timing response: {dataset_timing_response}"
            )
            self.info("TM: Using fixed schedule strategy")
            extra_kwargs["schedule"] = dataset_timing_response.timing_data
        elif self.config.timing_mode == TimingMode.CONCURRENCY:
            self.info("TM: Using concurrency strategy")
        elif self.config.timing_mode == TimingMode.REQUEST_RATE:
            self.info("TM: Using request rate strategy")
        else:
            # Unknown timing mode: no strategy can be created.
            raise InvalidStateError("No credit issuing strategy configured")

        # Single factory call shared by every mode (previously duplicated per branch).
        self._credit_issuing_strategy = CreditIssuingStrategyFactory.create_instance(
            self.config.timing_mode,
            config=self.config,
            credit_manager=self,
            **extra_kwargs,
        )
        if not self._credit_issuing_strategy:
            raise InvalidStateError("No credit issuing strategy configured")

    @on_start
    async def _timing_manager_start(self) -> None:
        """Start the timing manager and issue credit drops according to the configured strategy.

        Raises:
            InvalidStateError: If called before a strategy was configured.
        """
        self.debug("Starting timing manager")

        if not self._credit_issuing_strategy:
            raise InvalidStateError("No credit issuing strategy configured")

        # NOTE(review): fixed 2s startup delay -- presumably to give workers
        # time to connect before credits start flowing; confirm and consider
        # making it configurable or event-driven.
        await asyncio.sleep(2)
        self.execute_async(self._credit_issuing_strategy.start())

    @on_stop
    async def _timing_manager_stop(self) -> None:
        """Stop the timing manager, its strategy, and any outstanding tasks."""
        self.debug("Stopping timing manager")
        if self._credit_issuing_strategy:
            await self._credit_issuing_strategy.stop()
        await self.cancel_all_tasks()

    async def _on_credit_return(self, message: CreditReturnMessage) -> None:
        """Handle the credit return message by forwarding it to the strategy."""
        self.debug(lambda: f"Timing manager received credit return message: {message}")
        if self._credit_issuing_strategy:
            await self._credit_issuing_strategy._on_credit_return(message)

    async def drop_credit(
        self,
        credit_phase: CreditPhase,
        conversation_id: str | None = None,
        credit_drop_ns: int | None = None,
    ) -> None:
        """Drop a credit by pushing a CreditDropMessage to the workers (fire-and-forget).

        Args:
            credit_phase: The phase the credit belongs to.
            conversation_id: Optional conversation the credit is tied to.
            credit_drop_ns: Optional wall-clock time (ns) the credit should be
                honored at; None means as soon as possible.
        """
        self.execute_async(
            self.credit_drop_push_client.push(
                message=CreditDropMessage(
                    service_id=self.service_id,
                    phase=credit_phase,
                    credit_drop_ns=credit_drop_ns,
                    conversation_id=conversation_id,
                ),
            )
        )

service_type property

The type of service.

drop_credit(credit_phase, conversation_id=None, credit_drop_ns=None) async

Drop a credit.

Source code in aiperf/services/timing_manager/timing_manager.py
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
async def drop_credit(
    self,
    credit_phase: CreditPhase,
    conversation_id: str | None = None,
    credit_drop_ns: int | None = None,
) -> None:
    """Drop a credit by pushing a CreditDropMessage to the workers (fire-and-forget).

    Args:
        credit_phase: The phase the credit belongs to.
        conversation_id: Optional conversation the credit is tied to.
        credit_drop_ns: Optional wall-clock time (ns) the credit should be
            honored at; None means as soon as possible.
    """
    # execute_async: the push runs as a background task so the caller's
    # scheduling loop is never blocked on transport.
    self.execute_async(
        self.credit_drop_push_client.push(
            message=CreditDropMessage(
                service_id=self.service_id,
                phase=credit_phase,
                credit_drop_ns=credit_drop_ns,
                conversation_id=conversation_id,
            ),
        )
    )

main()

Main entry point for the timing manager.

Source code in aiperf/services/timing_manager/timing_manager.py
194
195
196
197
198
def main() -> None:
    """Entry point for running the TimingManager as a standalone service."""
    # Imported lazily so merely importing this module stays side-effect free.
    from aiperf.common.bootstrap import bootstrap_and_run_service as _bootstrap

    _bootstrap(TimingManager)

aiperf.services.workers.credit_processor_mixin

CreditProcessorMixin

Bases: CreditProcessorMixinRequirements

CreditProcessorMixin is a mixin that provides a method to process credit drops.

Source code in aiperf/services/workers/credit_processor_mixin.py
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
class CreditProcessorMixin(CreditProcessorMixinRequirements):
    """CreditProcessorMixin is a mixin that provides a method to process credit drops."""

    def __init__(self, **kwargs):
        # Runtime guard: the host class must structurally satisfy the protocol
        # (service_id, clients, model_endpoint, task_stats, ...) for the
        # methods below to work.
        # NOTE(review): super().__init__ is not called and **kwargs is unused,
        # so cooperative multiple inheritance stops here -- confirm this mixin
        # is ordered last among cooperative bases in the MRO.
        if not isinstance(self, CreditProcessorMixinRequirements):
            raise ValueError(
                "CreditProcessorMixin must be used with CreditProcessorMixinRequirements"
            )

    async def _process_credit_drop_internal(
        self, message: CreditDropMessage
    ) -> CreditReturnMessage:
        """Process a credit drop message. Return the CreditReturnMessage.

        - Every credit must be returned after processing
        - All results or errors should be converted to a RequestRecord and pushed to the inference results client.

        NOTE: This function MUST NOT return until the credit drop is fully processed.
        This is to ensure that the max concurrency is respected via the semaphore of the pull client.
        The way this is enforced is by requiring that this method returns a CreditReturnMessage.
        """
        # TODO: Add tests to ensure that the above note is never violated in the future

        self.trace(lambda: f"Processing credit drop: {message}")
        drop_perf_ns = time.perf_counter_ns()  # The time the credit was received

        # Lazily create the per-phase stats bucket on first credit of a phase.
        if message.phase not in self.task_stats:
            self.task_stats[message.phase] = WorkerPhaseTaskStats()
        self.task_stats[message.phase].total += 1

        # Default empty record so the finally block always has something to report,
        # even if the execution below raises before producing a record.
        record: RequestRecord = RequestRecord()
        try:
            record = await self._execute_single_credit_internal(message)

        except Exception as e:
            self.exception(f"Error processing credit drop: {e}")
            record.error = ErrorDetails.from_exception(e)
            record.end_perf_ns = time.perf_counter_ns()

        finally:
            record.credit_phase = message.phase
            msg = InferenceResultsMessage(
                service_id=self.service_id,
                record=record,
            )

            # Note that we already ensured that the phase exists in the task_stats dict in the above code.
            if not record.valid:
                self.task_stats[message.phase].failed += 1
            else:
                self.task_stats[message.phase].completed += 1

            try:
                await self.inference_results_push_client.push(msg)
            except Exception as e:
                # If we fail to push the record, log the error and continue
                self.exception(f"Error pushing request record: {e}")
            finally:
                # Calculate the latency of the credit drop (from when the credit was dropped to when the request was sent)
                pre_inference_ns = record.start_perf_ns - drop_perf_ns
                # Always return the credits
                return_message = CreditReturnMessage(
                    service_id=self.service_id,
                    delayed_ns=record.delayed_ns,
                    pre_inference_ns=pre_inference_ns,
                    phase=message.phase,
                )
                self.trace(lambda: f"Returning credit {return_message}")
                # Returning from inside a finally block (flake8 B012) is
                # deliberate: it swallows any in-flight exception so that a
                # CreditReturnMessage is produced no matter what failed above.
                return return_message  # noqa: B012

    async def _execute_single_credit_internal(
        self, message: CreditDropMessage
    ) -> RequestRecord:
        """Run a credit task for a single credit."""

        if not self.inference_client:
            raise NotInitializedError("Inference server client not initialized.")

        # retrieve the prompt from the dataset
        conversation_response: ConversationResponseMessage = (
            await self.conversation_request_client.request(
                ConversationRequestMessage(
                    service_id=self.service_id,
                    conversation_id=message.conversation_id,
                    credit_phase=message.phase,
                )
            )
        )
        self.trace(lambda: f"Received response message: {conversation_response}")

        # Dataset lookup failure: synthesize an error record instead of raising,
        # so the caller's accounting/return path still runs normally.
        if isinstance(conversation_response, ErrorMessage):
            return RequestRecord(
                model_name=self.model_endpoint.primary_model_name,
                conversation_id=message.conversation_id,
                turn_index=0,
                timestamp_ns=time.time_ns(),
                start_perf_ns=time.perf_counter_ns(),
                end_perf_ns=time.perf_counter_ns(),
                error=conversation_response.error,
            )

        # Only the first turn is used here (turn_index is hard-coded to 0 below).
        record = await self._call_inference_api_internal(
            message, conversation_response.conversation.turns[0]
        )
        record.model_name = self.model_endpoint.primary_model_name
        record.conversation_id = conversation_response.conversation.session_id
        record.turn_index = 0
        return record

    async def _call_inference_api_internal(
        self,
        message: CreditDropMessage,
        turn: Turn,
    ) -> RequestRecord:
        """Make a single call to the inference API. Will return an error record if the call fails."""
        self.trace(lambda: f"Calling inference API for turn: {turn}")
        # Pre-declare so the except block can reference whatever was captured
        # before the failure point (payload / timing may be None on early errors).
        formatted_payload = None
        pre_send_perf_ns = None
        timestamp_ns = None
        try:
            # Format payload for the API request
            formatted_payload = await self.request_converter.format_payload(
                model_endpoint=self.model_endpoint,
                turn=turn,
            )

            # NOTE: Current implementation of the TimingManager bypasses this, it is for future use.
            # Wait for the credit drop time if it is in the future.
            # Note that we check this after we have retrieved the data from the dataset, to ensure
            # that we are fully ready to go.
            delayed_ns = None
            drop_ns = message.credit_drop_ns
            now_ns = time.time_ns()
            if drop_ns and drop_ns > now_ns:
                self.trace(
                    lambda: f"Waiting for credit drop expected time: {(drop_ns - now_ns) / NANOS_PER_SECOND:.2f} s"
                )
                await asyncio.sleep((drop_ns - now_ns) / NANOS_PER_SECOND)
            elif drop_ns and drop_ns < now_ns:
                # The credit arrived late; record by how much.
                delayed_ns = now_ns - drop_ns

            # Save the current perf_ns before sending the request so it can be used to calculate
            # the start_perf_ns of the request in case of an exception.
            pre_send_perf_ns = time.perf_counter_ns()
            timestamp_ns = time.time_ns()

            # Send the request to the Inference Server API and wait for the response
            result: RequestRecord = await self.inference_client.send_request(
                model_endpoint=self.model_endpoint,
                payload=formatted_payload,
            )

            self.debug(
                lambda: f"pre_send_perf_ns to start_perf_ns latency: {result.start_perf_ns - pre_send_perf_ns} ns"
            )

            result.delayed_ns = delayed_ns
            return result

        except Exception as e:
            self.exception(
                f"Error calling inference server API at {self.model_endpoint.url}: {e}"
            )
            return RequestRecord(
                request=formatted_payload,
                timestamp_ns=timestamp_ns or time.time_ns(),
                # Try and use the pre_send_perf_ns if it is available, otherwise use the current time.
                start_perf_ns=pre_send_perf_ns or time.perf_counter_ns(),
                end_perf_ns=time.perf_counter_ns(),
                error=ErrorDetails.from_exception(e),
            )

CreditProcessorMixinRequirements

Bases: AIPerfLoggerProtocol, Protocol

CreditProcessorMixinRequirements is a protocol that provides the requirements needed for the CreditProcessorMixin.

Source code in aiperf/services/workers/credit_processor_mixin.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
@runtime_checkable
class CreditProcessorMixinRequirements(AIPerfLoggerProtocol, Protocol):
    """CreditProcessorMixinRequirements is a protocol that provides the requirements needed for the CreditProcessorMixin."""

    # Unique identifier of the hosting service (stamped on outgoing messages).
    service_id: str
    # Client used to send requests to the inference server API.
    inference_client: InferenceClientProtocol
    # Request/response client for fetching conversations from the dataset manager.
    conversation_request_client: RequestClientProtocol
    # Push client for publishing RequestRecords to the inference results pipeline.
    inference_results_push_client: PushClientProtocol
    # Converts a Turn into the payload format expected by the endpoint.
    request_converter: RequestConverterProtocol
    # Endpoint/model metadata for the target inference server.
    model_endpoint: ModelEndpointInfo
    # Per-phase counters (total/completed/failed) updated as credits are processed.
    task_stats: dict[CreditPhase, WorkerPhaseTaskStats]

    async def _process_credit_drop_internal(
        self, message: CreditDropMessage
    ) -> CreditReturnMessage:
        """Process a credit drop message. Return the CreditReturnMessage."""
        ...

    async def _execute_single_credit_internal(
        self, message: CreditDropMessage
    ) -> RequestRecord:
        """Execute a single credit drop. Return the RequestRecord."""
        ...

    async def _call_inference_api_internal(
        self,
        message: CreditDropMessage,
        turn: Turn,
    ) -> RequestRecord:
        """Make a single call to the inference API. Will return an error record if the call fails."""
        ...

CreditProcessorProtocol

Bases: Protocol

CreditProcessorProtocol is a protocol that provides a method to process credit drops.

Source code in aiperf/services/workers/credit_processor_mixin.py
29
30
31
32
33
34
35
36
37
@runtime_checkable
class CreditProcessorProtocol(Protocol):
    """CreditProcessorProtocol is a protocol that provides a method to process credit drops.

    Narrower than CreditProcessorMixinRequirements: only the single entry
    point is required, for callers that just dispatch credit drops.
    """

    async def _process_credit_drop_internal(
        self, message: CreditDropMessage
    ) -> CreditReturnMessage:
        """Process a credit drop message. Return the CreditReturnMessage."""
        ...

aiperf.services.workers.worker

Worker

Bases: BaseComponentService, ProcessHealthMixin, CreditProcessorMixin

Worker is primarily responsible for making API calls to the inference server. It also manages the conversation between turns and returns the results to the Inference Results Parsers.

Source code in aiperf/services/workers/worker.py
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
@ServiceFactory.register(ServiceType.WORKER)
class Worker(BaseComponentService, ProcessHealthMixin, CreditProcessorMixin):
    """Worker is primarily responsible for making API calls to the inference server.
    It also manages the conversation between turns and returns the results to the Inference Results Parsers.
    """

    def __init__(
        self,
        service_config: ServiceConfig,
        user_config: UserConfig | None = None,
        service_id: str | None = None,
        **kwargs,
    ):
        super().__init__(
            service_config=service_config,
            user_config=user_config,
            service_id=service_id,
            **kwargs,
        )

        self.debug(lambda: f"Initializing worker process: {self.process.pid}")

        # How often (seconds) the health task publishes WorkerHealthMessages.
        self.health_check_interval = (
            self.service_config.workers.health_check_interval_seconds
        )

        # Per-phase counters; populated lazily by CreditProcessorMixin.
        self.task_stats: dict[CreditPhase, WorkerPhaseTaskStats] = {}

        # Receives credit drops (connects to the TimingManager's bound socket).
        self.credit_drop_pull_client: PullClientProtocol = (
            self.comms.create_pull_client(
                CommunicationClientAddressType.CREDIT_DROP,
            )
        )
        # Returns processed credits back to the TimingManager.
        self.credit_return_push_client: PushClientProtocol = (
            self.comms.create_push_client(
                CommunicationClientAddressType.CREDIT_RETURN,
            )
        )
        # Pushes raw inference RequestRecords downstream for parsing.
        self.inference_results_push_client: PushClientProtocol = (
            self.comms.create_push_client(
                CommunicationClientAddressType.RAW_INFERENCE_PROXY_FRONTEND,
            )
        )
        # Fetches conversation data (prompts) from the dataset manager.
        self.conversation_request_client: RequestClientProtocol = (
            self.comms.create_request_client(
                CommunicationClientAddressType.DATASET_MANAGER_PROXY_FRONTEND,
            )
        )

        self.model_endpoint = ModelEndpointInfo.from_user_config(self.user_config)

        self.debug(
            lambda: f"Creating inference client for {self.model_endpoint.endpoint.type}, "
            f"class: {InferenceClientFactory.get_class_from_type(self.model_endpoint.endpoint.type).__name__}",
        )
        # Converter and client are both selected by the endpoint type.
        self.request_converter = RequestConverterFactory.create_instance(
            self.model_endpoint.endpoint.type,
        )
        self.inference_client = InferenceClientFactory.create_instance(
            self.model_endpoint.endpoint.type,
            model_endpoint=self.model_endpoint,
        )

    @property
    def service_type(self) -> ServiceType:
        # Identifies this service to the factory/registry machinery.
        return ServiceType.WORKER

    @on_init
    async def _initialize_worker(self) -> None:
        self.debug("Initializing worker")

        # Route incoming CREDIT_DROP messages to the processing callback.
        await self.credit_drop_pull_client.register_pull_callback(
            MessageType.CREDIT_DROP, self._credit_drop_callback
        )

        self.debug("Worker initialized")

    @on_configure
    async def _configure_worker(self, message: CommandMessage) -> None:
        # NOTE: Right now we are configuring the worker in the __init__ method,
        #       but that may change based on how we implement sweeps.
        pass

    async def _credit_drop_callback(self, message: CreditDropMessage) -> None:
        """Handle an incoming credit drop message. Every credit must be returned after processing."""

        # Create a default credit return message in case of an exception
        credit_return_message = CreditReturnMessage(
            service_id=self.service_id,
            phase=message.phase,
        )

        try:
            # NOTE: This must be awaited to ensure that the max concurrency is respected
            credit_return_message = await self._process_credit_drop_internal(message)
        except Exception as e:
            self.exception(f"Error processing credit drop: {e}")
        finally:
            # It is fine to execute the push asynchronously here because the worker is technically
            # ready to process the next credit drop.
            self.execute_async(
                self.credit_return_push_client.push(credit_return_message)
            )

    @on_stop
    async def _shutdown_worker(self) -> None:
        self.debug("Shutting down worker")
        if self.inference_client:
            await self.inference_client.close()

    @aiperf_task
    async def _health_check_task(self) -> None:
        """Task to report the health of the worker to the worker manager."""
        while True:
            try:
                health_message = self.create_health_message()
                await self.pub_client.publish(health_message)
            except Exception as e:
                self.exception(f"Error reporting health: {e}")
            # On Python 3.8+, CancelledError derives from BaseException, so the
            # handler below is reachable despite following `except Exception`.
            except asyncio.CancelledError:
                self.debug("Health check task cancelled")
                break

            # Cancellation during this sleep propagates out and ends the task,
            # which is the conventional shutdown path for a looping task.
            await asyncio.sleep(self.health_check_interval)

    def create_health_message(self) -> WorkerHealthMessage:
        """Snapshot current process health and per-phase task stats into a health message."""
        return WorkerHealthMessage(
            service_id=self.service_id,
            process=self.get_process_health(),
            task_stats=self.task_stats,
        )

aiperf.services.workers.worker_manager

WorkerManager

Bases: BaseComponentService

The WorkerManager service's primary responsibility is to manage the worker processes. It spawns the workers, monitors their health, and stops them when the service is stopped. In the future it will also be responsible for auto-scaling the workers.

Source code in aiperf/services/workers/worker_manager.py
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
@ServiceFactory.register(ServiceType.WORKER_MANAGER)
class WorkerManager(BaseComponentService):
    """
    The WorkerManager service's primary responsibility is to manage the worker processes.
    It spawns the workers, monitors their health, and stops them when the service is stopped.
    In the future it will also be responsible for auto-scaling the workers.
    """

    def __init__(
        self,
        service_config: ServiceConfig,
        user_config: UserConfig | None = None,
        service_id: str | None = None,
    ) -> None:
        """Initialize the manager and compute the target worker count.

        Args:
            service_config: Service-level configuration (worker min/max, run type).
            user_config: User benchmark configuration (load generation settings).
            service_id: Optional explicit id for this service instance.
        """
        super().__init__(
            service_config=service_config,
            user_config=user_config,
            service_id=service_id,
        )
        self.trace("WorkerManager.__init__")
        # worker_id -> process bookkeeping / latest health report
        self.workers: dict[str, WorkerProcessInfo] = {}
        self.worker_health: dict[str, WorkerHealthMessage] = {}

        self.cpu_count = multiprocessing.cpu_count()
        self.debug("Detected %s CPU cores/threads", self.cpu_count)

        self.max_concurrency = self.user_config.loadgen.concurrency
        self.max_workers = self.service_config.workers.max
        if self.max_workers is None:
            # Default to the number of CPU cores - 1, but never below 1
            # (cpu_count() == 1 would otherwise yield zero workers).
            self.max_workers = max(self.cpu_count - 1, 1)

        # Cap the worker count to the max concurrency + 1, but only if the user is in concurrency mode.
        if self.max_concurrency > 1:
            self.max_workers = min(
                self.max_concurrency + 1,
                self.max_workers,
            )

        # Ensure we have at least the min workers
        self.max_workers = max(
            self.max_workers,
            self.service_config.workers.min or 0,
        )
        self.initial_workers = self.max_workers

    @property
    def service_type(self) -> ServiceType:
        """This service's type tag used by the service factory/registry."""
        return ServiceType.WORKER_MANAGER

    @on_init
    async def _initialize(self) -> None:
        """Initialize worker manager-specific components."""
        self.debug("WorkerManager initializing")

        # Collect health reports published by the workers.
        await self.sub_client.subscribe(
            MessageType.WORKER_HEALTH, self._on_worker_health
        )

        # Spawn workers
        # TODO: This logic can be refactored to make use of the ServiceManager class
        if self.service_config.service_run_type == ServiceRunType.MULTIPROCESSING:
            await self._spawn_multiprocessing_workers()

        elif self.service_config.service_run_type == ServiceRunType.KUBERNETES:
            await self._spawn_kubernetes_workers()

        else:
            raise ConfigurationError(
                f"Unsupported run type: {self.service_config.service_run_type}",
            )

    async def _on_worker_health(self, message: WorkerHealthMessage) -> None:
        """Record the most recent health report for the sending worker."""
        self.debug("Received worker health message: %s", message)
        self.worker_health[message.service_id] = message

    @on_stop
    async def _stop(self) -> None:
        """Stop all workers according to the configured run type."""
        self.debug("WorkerManager stopping")

        # Stop all workers
        # TODO: This logic can be refactored to make use of the ServiceManager class
        if self.service_config.service_run_type == ServiceRunType.MULTIPROCESSING:
            await self._stop_multiprocessing_workers()
        elif self.service_config.service_run_type == ServiceRunType.KUBERNETES:
            await self._stop_kubernetes_workers()
        else:
            raise ConfigurationError(
                f"Unsupported run type: {self.service_config.service_run_type}",
            )

    @on_cleanup
    async def _cleanup(self) -> None:
        """Drop all worker bookkeeping after the workers have been stopped."""
        self.debug("WorkerManager cleaning up")
        self.workers.clear()

    async def _spawn_kubernetes_workers(self) -> None:
        """Spawn worker pods on Kubernetes (not yet implemented)."""
        self.debug("Spawning %s worker pods", self.initial_workers)
        # TODO: Implement Kubernetes start
        raise NotImplementedError("Kubernetes start not implemented")

    async def _stop_kubernetes_workers(self) -> None:
        """Stop worker pods on Kubernetes (not yet implemented)."""
        self.debug("Stopping all worker processes")
        # TODO: Implement Kubernetes stop
        raise NotImplementedError("Kubernetes stop not implemented")

    async def _spawn_multiprocessing_workers(self) -> None:
        """Spawn ``initial_workers`` daemon Worker processes locally."""
        self.debug("Spawning %s worker processes", self.initial_workers)

        # Get the global log queue for child process logging
        from aiperf.common.logging import get_global_log_queue

        log_queue = get_global_log_queue()

        for _ in range(self.initial_workers):
            worker_id = f"worker_{uuid.uuid4().hex[:8]}"

            process = Process(
                target=bootstrap_and_run_service,
                name=f"{worker_id}_process",
                kwargs={
                    "service_class": Worker,
                    "service_config": self.service_config,
                    "user_config": self.user_config,
                    "log_queue": log_queue,
                    "service_id": worker_id,
                },
                # Daemonize so stray workers die with the manager process.
                daemon=True,
            )
            process.start()

            self.workers[worker_id] = WorkerProcessInfo(
                worker_id=worker_id,
                process=process,
            )
            self.debug(
                lambda id=worker_id,
                pid=process.pid: f"Started worker process {id} (pid: {pid})"
            )

    async def _stop_multiprocessing_workers(self) -> None:
        """Terminate all local worker processes, then wait for them to exit."""
        self.debug("Stopping all worker processes")

        # First terminate all processes
        for worker_id, worker_info in self.workers.items():
            self.debug(
                lambda id=worker_id,
                pid=worker_info.process.pid: f"Stopping worker process {id} (pid: {pid})"
            )
            process = worker_info.process
            if process and process.is_alive():
                self.debug(
                    lambda id=worker_id,
                    pid=process.pid: f"Terminating worker process {id} (pid: {pid})"
                )
                process.terminate()

        # Then wait for all to finish
        await asyncio.gather(
            *[
                self._wait_for_process(worker_id, worker_info.process)
                for worker_id, worker_info in self.workers.items()
                if worker_info.process
            ]
        )

        self.debug("All worker processes stopped")

    async def _wait_for_process(self, worker_id: str, process: Process) -> None:
        """Wait for a process to terminate with timeout handling."""
        # NOTE: Process.join(timeout=...) simply RETURNS when the timeout
        # elapses -- it never raises. The previous `except asyncio.TimeoutError`
        # branch was unreachable, so a hung worker was never killed. Check
        # is_alive() after the join instead.
        await asyncio.to_thread(process.join, timeout=TASK_CANCEL_TIMEOUT_SHORT)
        if process.is_alive():
            self.warning(
                lambda id=worker_id,
                pid=process.pid: f"Worker process {id} (pid: {pid}) did not terminate gracefully, killing"
            )
            process.kill()
        else:
            self.debug(
                lambda id=worker_id,
                pid=process.pid: f"Worker process {id} (pid: {pid}) stopped"
            )
WorkerProcessInfo

Bases: AIPerfBaseModel

Information about a worker process.

Source code in aiperf/services/workers/worker_manager.py
29
30
31
32
33
34
35
class WorkerProcessInfo(AIPerfBaseModel):
    """Information about a worker process."""

    # Allow non-pydantic types (e.g. multiprocessing.Process) as field values.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    worker_id: str = Field(..., description="ID of the worker process")
    # Typed as Any so it can hold a multiprocessing.Process (or, per the
    # description, a task handle); defaults to None until a process is attached.
    process: Any = Field(None, description="Process object or task")